#!/usr/bin/env python3
"""
Image Aligner - FastAPI Web Interface with API
Dedicated with love and devotion to Alon Y., Daniel B., Denis Z., Tal S., Adi B.
and the rest of the Animation Taskforce 2026
"""
import os
import base64
import subprocess
import tempfile

import cv2
import numpy as np
from fastapi import FastAPI, File, UploadFile, Form, HTTPException
from fastapi.responses import HTMLResponse, Response
from fastapi.middleware.cors import CORSMiddleware
from skimage import exposure
import uvicorn
# ============== Image Alignment Core ==============

def extract_features(img: np.ndarray) -> tuple:
    """Detect SIFT keypoints/descriptors, scaling the feature budget with image size."""
    gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
    n_pixels = img.shape[0] * img.shape[1]
    nfeatures = min(10000, max(2000, n_pixels // 200))
    sift = cv2.SIFT_create(nfeatures=nfeatures, contrastThreshold=0.02, edgeThreshold=15)
    keypoints, descriptors = sift.detectAndCompute(gray, None)
    return keypoints, descriptors

def match_features(desc1: np.ndarray, desc2: np.ndarray, ratio_thresh: float = 0.85) -> list:
    """FLANN kNN matching with Lowe's ratio test to discard ambiguous matches."""
    if desc1 is None or desc2 is None:
        return []
    FLANN_INDEX_KDTREE = 1
    index_params = dict(algorithm=FLANN_INDEX_KDTREE, trees=5)
    search_params = dict(checks=150)
    flann = cv2.FlannBasedMatcher(index_params, search_params)
    try:
        matches = flann.knnMatch(desc1, desc2, k=2)
    except cv2.error:
        return []
    good_matches = []
    for match_pair in matches:
        if len(match_pair) == 2:
            m, n = match_pair
            if m.distance < ratio_thresh * n.distance:
                good_matches.append(m)
    return good_matches

def compute_homography(kp1, kp2, matches, ransac_reproj_thresh=8.0, confidence=0.9999):
    """Estimate a robust homography (MAGSAC++) from matched keypoints."""
    if len(matches) < 4:
        return None, None
    src_pts = np.float32([kp1[m.queryIdx].pt for m in matches]).reshape(-1, 1, 2)
    dst_pts = np.float32([kp2[m.trainIdx].pt for m in matches]).reshape(-1, 1, 2)
    H, mask = cv2.findHomography(
        src_pts, dst_pts,
        method=cv2.USAC_MAGSAC,
        ransacReprojThreshold=ransac_reproj_thresh,
        maxIters=10000,
        confidence=confidence
    )
    return H, mask
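
# Illustrative sketch (not part of the pipeline): how an estimated homography
# maps individual points via cv2.perspectiveTransform. The 3x3 matrix and the
# sample points below are hypothetical.
def _project_points_demo():
    H = np.array([[1.0, 0.0, 10.0],
                  [0.0, 1.0, 5.0],
                  [0.0, 0.0, 1.0]])  # pure translation by (10, 5)
    pts = np.float32([[0, 0], [100, 50]]).reshape(-1, 1, 2)
    return cv2.perspectiveTransform(pts, H)  # -> [[[10, 5]], [[110, 55]]]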

def create_inlier_mask(keypoints, matches, inlier_mask, image_shape, radius=50):
    """Build a boolean mask covering discs around the RANSAC-inlier keypoints."""
    h, w = image_shape[:2]
    mask_img = np.zeros((h, w), dtype=np.uint8)
    for i, m in enumerate(matches):
        if inlier_mask[i]:
            pt = keypoints[m.trainIdx].pt
            cv2.circle(mask_img, (int(pt[0]), int(pt[1])), radius, 1, -1)
    kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (radius, radius))
    mask_img = cv2.dilate(mask_img, kernel, iterations=2)
    return mask_img.astype(bool)

def _build_histogram_lookup_float(src_channel, tgt_channel, n_bins=256):
    """Map source intensities to target intensities by matching their CDFs."""
    src_hist, _ = np.histogram(src_channel.flatten(), bins=n_bins, range=(0, 256))
    tgt_hist, _ = np.histogram(tgt_channel.flatten(), bins=n_bins, range=(0, 256))
    src_cdf = np.cumsum(src_hist).astype(np.float64)
    src_cdf = src_cdf / (src_cdf[-1] + 1e-10)
    tgt_cdf = np.cumsum(tgt_hist).astype(np.float64)
    tgt_cdf = tgt_cdf / (tgt_cdf[-1] + 1e-10)
    # Clip before casting: searchsorted can return n_bins, which would wrap
    # around to 0 in the uint8 variant below.
    lookup = np.clip(np.searchsorted(tgt_cdf, src_cdf), 0, n_bins - 1)
    return lookup.astype(np.float32)

def _build_histogram_lookup(src_channel, tgt_channel, n_bins=256):
    """uint8 variant of the CDF-matching lookup table."""
    return _build_histogram_lookup_float(src_channel, tgt_channel, n_bins).astype(np.uint8)
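
# Illustrative sketch of the CDF lookup on toy data: a dark source histogram
# gets remapped toward a brighter target. The toy arrays are hypothetical;
# real callers pass masked image channels.
def _lookup_demo():
    src = np.array([10, 10, 20, 30], dtype=np.uint8)
    tgt = np.array([100, 110, 120, 130], dtype=np.uint8)
    lookup = _build_histogram_lookup(src, tgt)
    return lookup[src]  # source intensities pushed into the target's range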

def histogram_matching_lab(source, target, mask=None):
    """Per-channel histogram matching in LAB space; optionally fit on masked pixels only."""
    source_lab = cv2.cvtColor(source, cv2.COLOR_BGR2LAB).astype(np.float32)
    target_lab = cv2.cvtColor(target, cv2.COLOR_BGR2LAB).astype(np.float32)
    matched_lab = np.zeros_like(source_lab)
    if mask is None:
        for i in range(3):
            matched_lab[:, :, i] = exposure.match_histograms(source_lab[:, :, i], target_lab[:, :, i])
    else:
        for i in range(3):
            src_masked = source_lab[:, :, i][mask]
            tgt_masked = target_lab[:, :, i][mask]
            lookup = _build_histogram_lookup_float(src_masked, tgt_masked)
            # Linearly interpolate between adjacent lookup bins for float inputs.
            src_channel = source_lab[:, :, i]
            src_floor = np.floor(src_channel).astype(np.int32)
            src_ceil = np.minimum(src_floor + 1, 255)
            src_frac = src_channel - src_floor
            src_floor = np.clip(src_floor, 0, 255)
            matched_lab[:, :, i] = (1 - src_frac) * lookup[src_floor] + src_frac * lookup[src_ceil]
    matched_lab = np.clip(matched_lab, 0, 255).astype(np.uint8)
    return cv2.cvtColor(matched_lab, cv2.COLOR_LAB2BGR)

def histogram_matching_rgb(source, target, mask=None):
    """Per-channel histogram matching directly in BGR space."""
    matched = np.zeros_like(source)
    if mask is None:
        for i in range(3):
            matched[:, :, i] = exposure.match_histograms(source[:, :, i], target[:, :, i])
        return matched
    for i in range(3):
        src_masked = source[:, :, i][mask]
        tgt_masked = target[:, :, i][mask]
        lookup = _build_histogram_lookup(src_masked, tgt_masked)
        matched[:, :, i] = lookup[source[:, :, i]]
    return matched

def piecewise_linear_histogram_transfer(source, target, n_bins=256, mask=None):
    """CDF matching with linear interpolation between bins, applied per BGR channel."""
    result = np.zeros_like(source, dtype=np.float32)
    for c in range(3):
        if mask is not None:
            src_channel = source[:, :, c][mask].astype(np.float32)
            tgt_channel = target[:, :, c][mask].astype(np.float32)
        else:
            src_channel = source[:, :, c].flatten().astype(np.float32)
            tgt_channel = target[:, :, c].flatten().astype(np.float32)
        lookup = _build_histogram_lookup_float(src_channel, tgt_channel, n_bins)
        src_img = source[:, :, c].astype(np.float32)
        src_floor = np.floor(src_img).astype(np.int32)
        src_ceil = np.minimum(src_floor + 1, n_bins - 1)
        src_frac = src_img - src_floor
        src_floor = np.clip(src_floor, 0, n_bins - 1)
        result[:, :, c] = (1 - src_frac) * lookup[src_floor] + src_frac * lookup[src_ceil]
    return np.clip(result, 0, 255).astype(np.uint8)

def fast_color_transfer(source, target, mask=None):
    """Reinhard-style mean/std color transfer in LAB space."""
    src_lab = cv2.cvtColor(source, cv2.COLOR_BGR2LAB).astype(np.float32)
    tgt_lab = cv2.cvtColor(target, cv2.COLOR_BGR2LAB).astype(np.float32)
    if mask is not None:
        src_stats = src_lab[mask]
        tgt_stats = tgt_lab[mask]
    else:
        src_stats = src_lab.reshape(-1, 3)
        tgt_stats = tgt_lab.reshape(-1, 3)
    for i in range(3):
        s_mean, s_std = src_stats[:, i].mean(), src_stats[:, i].std() + 1e-6
        t_mean, t_std = tgt_stats[:, i].mean(), tgt_stats[:, i].std() + 1e-6
        src_lab[:, :, i] = (src_lab[:, :, i] - s_mean) * (t_std / s_std) + t_mean
    return cv2.cvtColor(np.clip(src_lab, 0, 255).astype(np.uint8), cv2.COLOR_LAB2BGR)

def full_histogram_matching(source, target, mask=None):
    """Blend of LAB, piecewise-linear, and BGR histogram matching (0.5/0.3/0.2 weights)."""
    lab_matched = histogram_matching_lab(source, target, mask)
    cdf_matched = piecewise_linear_histogram_transfer(source, target, mask=mask)
    multi_matched = histogram_matching_rgb(source, target, mask)
    result = (0.5 * lab_matched.astype(np.float32) +
              0.3 * cdf_matched.astype(np.float32) +
              0.2 * multi_matched.astype(np.float32))
    return np.clip(result, 0, 255).astype(np.uint8)
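
# Illustrative sketch exercising the color-transfer variants on synthetic
# flat-color frames (hypothetical inputs; in the pipeline the optional mask
# comes from create_inlier_mask).
def _color_transfer_demo():
    src = np.full((64, 64, 3), (40, 80, 120), dtype=np.uint8)   # bluish BGR
    tgt = np.full((64, 64, 3), (120, 80, 40), dtype=np.uint8)   # orangish BGR
    stats_matched = fast_color_transfer(src, tgt)
    hist_matched = full_histogram_matching(src, tgt)
    return stats_matched, hist_matched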

# ============== Post-Processing ==============
# Level configs: (blur_sigma_mult, blur_sigma_min, motion_min, crf_boost)
PP_LEVELS = {
    0: None,  # disabled
    1: {'sigma_mult': 0.8, 'sigma_min': 0.0, 'motion_min': 1, 'crf_boost': 0},
    2: {'sigma_mult': 1.1, 'sigma_min': 0.0, 'motion_min': 1, 'crf_boost': 2},
    3: {'sigma_mult': 1.5, 'sigma_min': 0.3, 'motion_min': 3, 'crf_boost': 5},
}

def detect_foreground_mask(aligned, target, threshold=25, min_area=500):
    """Soft mask of regions where the aligned image still differs from the target."""
    diff = cv2.absdiff(aligned, target)
    diff_gray = cv2.cvtColor(diff, cv2.COLOR_BGR2GRAY)
    _, binary = cv2.threshold(diff_gray, threshold, 255, cv2.THRESH_BINARY)
    kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (15, 15))
    binary = cv2.morphologyEx(binary, cv2.MORPH_CLOSE, kernel, iterations=2)
    binary = cv2.morphologyEx(binary, cv2.MORPH_OPEN, kernel, iterations=1)
    num_labels, labels, stats, _ = cv2.connectedComponentsWithStats(binary, connectivity=8)
    cleaned = np.zeros_like(binary)
    for i in range(1, num_labels):
        if stats[i, cv2.CC_STAT_AREA] >= min_area:
            cleaned[labels == i] = 255
    return cv2.GaussianBlur(cleaned.astype(np.float32) / 255.0, (31, 31), 0)

def estimate_blur_level(image):
    """Estimate a Gaussian-blur sigma from Laplacian variance (higher variance = sharper)."""
    gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
    laplacian_var = cv2.Laplacian(gray, cv2.CV_64F).var()
    if laplacian_var > 500:
        return 0.0
    elif laplacian_var < 10:
        return 3.0
    return max(0.0, 2.5 - np.log10(laplacian_var) * 0.9)

def estimate_motion_blur(image):
    """Estimate motion-blur kernel size and angle from directional FFT energy."""
    gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY).astype(np.float64)
    f = np.fft.fft2(gray)
    fshift = np.fft.fftshift(f)
    magnitude = np.log1p(np.abs(fshift))
    h, w = magnitude.shape
    cy, cx = h // 2, w // 2
    radius = min(h, w) // 4
    angles_deg = np.arange(0, 180, 5)
    angles_rad = np.deg2rad(angles_deg)
    rs = np.arange(5, radius)
    dx = np.cos(angles_rad)
    dy = np.sin(angles_rad)
    # Sample spectrum energy along rays in both directions from the center.
    X_pos = (cx + np.outer(rs, dx)).astype(int)
    Y_pos = (cy + np.outer(rs, dy)).astype(int)
    X_neg = (cx - np.outer(rs, dx)).astype(int)
    Y_neg = (cy - np.outer(rs, dy)).astype(int)
    valid_pos = (X_pos >= 0) & (X_pos < w) & (Y_pos >= 0) & (Y_pos < h)
    valid_neg = (X_neg >= 0) & (X_neg < w) & (Y_neg >= 0) & (Y_neg < h)
    energy_pos = np.where(valid_pos, magnitude[np.clip(Y_pos, 0, h - 1), np.clip(X_pos, 0, w - 1)], 0.0)
    energy_neg = np.where(valid_neg, magnitude[np.clip(Y_neg, 0, h - 1), np.clip(X_neg, 0, w - 1)], 0.0)
    total_energy = energy_pos.sum(axis=0) + energy_neg.sum(axis=0)
    total_count = valid_pos.sum(axis=0) + valid_neg.sum(axis=0)
    valid_angles = total_count > 0
    avg_energies = np.where(valid_angles, total_energy / (total_count + 1e-10), 0.0)
    if valid_angles.any():
        min_idx = np.argmin(np.where(valid_angles, avg_energies, np.inf))
        max_idx = np.argmax(np.where(valid_angles, avg_energies, -np.inf))
        best_angle = angles_deg[min_idx]
        min_energy = avg_energies[min_idx]
        max_energy = avg_energies[max_idx]
    else:
        best_angle, min_energy, max_energy = 0.0, 0.0, 0.0
    # Motion blur suppresses spectral energy along its direction, so the blur
    # axis is perpendicular to the lowest-energy direction.
    blur_angle = (best_angle + 90) % 180
    anisotropy = (max_energy - min_energy) / (max_energy + 1e-6)
    kernel_size = 1 if anisotropy < 0.05 else max(1, int(anisotropy * 25))
    return kernel_size, blur_angle

def estimate_crf(image):
    """Rough CRF estimate from high-frequency energy and 4-px block-boundary contrast."""
    gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY).astype(np.float64)
    h, w = gray.shape
    laplacian = cv2.Laplacian(gray, cv2.CV_64F)
    hf_energy = np.mean(np.abs(laplacian))
    # Gradient magnitude across 4-pixel block boundaries vs. interior columns.
    cols_4 = np.arange(4, w - 1, 4)
    rows_4 = np.arange(4, h - 1, 4)
    block_diffs_x = np.mean(np.abs(gray[:, cols_4] - gray[:, cols_4 - 1]), axis=0) if len(cols_4) else np.array([])
    block_diffs_y = np.mean(np.abs(gray[rows_4, :] - gray[rows_4 - 1, :]), axis=1) if len(rows_4) else np.array([])
    block_diffs = np.concatenate([block_diffs_x, block_diffs_y])
    cols_interior = np.arange(3, w - 1, 4)  # offset by one from the block boundaries
    interior_diffs = np.mean(np.abs(gray[:, cols_interior] - gray[:, cols_interior - 1]), axis=0) if len(cols_interior) else np.array([])
    avg_block = np.median(block_diffs) if len(block_diffs) else 0
    avg_interior = np.median(interior_diffs) if len(interior_diffs) else 1
    blockiness = avg_block / (avg_interior + 1e-6)
    if hf_energy > 30:
        crf_from_hf = 15
    elif hf_energy > 15:
        crf_from_hf = 23
    elif hf_energy > 8:
        crf_from_hf = 30
    else:
        crf_from_hf = 38
    crf_from_blockiness = 18 + int((blockiness - 1.0) * 20)
    crf = int(0.6 * crf_from_hf + 0.4 * crf_from_blockiness)
    return max(0, min(51, crf))
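
# Illustrative sketch: comparing estimate_crf on synthetic frames. A flat
# frame should score a higher (more compressed-looking) CRF than pure noise;
# the exact numbers depend on the heuristics above.
def _crf_demo():
    rng = np.random.default_rng(0)
    noisy = rng.integers(0, 256, (128, 128, 3), dtype=np.uint8)
    flat = np.full((128, 128, 3), 128, dtype=np.uint8)
    return estimate_crf(noisy), estimate_crf(flat)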

def apply_h264_compression(image, crf=23):
    """Round-trip the image through a single H.264 frame at the given CRF via ffmpeg."""
    h, w = image.shape[:2]
    with tempfile.TemporaryDirectory() as tmpdir:
        in_path = os.path.join(tmpdir, 'in.png')
        out_path = os.path.join(tmpdir, 'out.mp4')
        dec_path = os.path.join(tmpdir, 'dec.png')
        cv2.imwrite(in_path, image)
        subprocess.run([
            'ffmpeg', '-y', '-hide_banner', '-loglevel', 'error',
            '-i', in_path,
            '-c:v', 'libx264', '-crf', str(crf),
            '-pix_fmt', 'yuv420p', '-frames:v', '1',
            out_path
        ], check=True, capture_output=True)
        subprocess.run([
            'ffmpeg', '-y', '-hide_banner', '-loglevel', 'error',
            '-i', out_path, '-frames:v', '1', dec_path
        ], check=True, capture_output=True)
        result = cv2.imread(dec_path)
    # yuv420p can round odd dimensions; restore the original size if needed.
    if result.shape[:2] != (h, w):
        result = cv2.resize(result, (w, h), interpolation=cv2.INTER_LANCZOS4)
    return result

def apply_motion_blur(image, kernel_size=11, angle=0.0):
    """Convolve with a normalized line kernel oriented at `angle` degrees."""
    if kernel_size <= 1:
        return image.copy()
    kernel = np.zeros((kernel_size, kernel_size), dtype=np.float32)
    center = kernel_size // 2
    angle_rad = np.deg2rad(angle)
    dx, dy = np.cos(angle_rad), np.sin(angle_rad)
    for i in range(kernel_size):
        t = i - center
        x, y = int(round(center + t * dx)), int(round(center + t * dy))
        if 0 <= x < kernel_size and 0 <= y < kernel_size:
            kernel[y, x] = 1.0
    kernel /= kernel.sum() + 1e-8
    return cv2.filter2D(image, -1, kernel)
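
# Illustrative sketch: the motion-blur kernel is a normalized line of ones.
# Convolving a hypothetical single-pixel impulse at 0 degrees smears it
# horizontally across the kernel width.
def _motion_kernel_demo():
    impulse = np.zeros((5, 5, 3), dtype=np.float32)
    impulse[2, 2] = 255.0
    blurred = apply_motion_blur(impulse, kernel_size=5, angle=0.0)
    return blurred[2]  # middle row carries the smeared impulse (~51 per pixel)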

def apply_gaussian_blur(image, sigma):
    if sigma <= 0:
        return image.copy()
    ksize = int(np.ceil(sigma * 6)) | 1  # force an odd kernel size
    return cv2.GaussianBlur(image, (ksize, ksize), sigma)

def postprocess_foreground(aligned, target, level=2):
    """Degrade the edited foreground so it matches the target's blur and compression."""
    if level <= 0 or level not in PP_LEVELS:
        return aligned
    cfg = PP_LEVELS[level]
    # Estimate target degradation
    blur_sigma = estimate_blur_level(target)
    motion_kernel, motion_angle = estimate_motion_blur(target)
    crf = estimate_crf(target)
    # Detect foreground
    fg_mask = detect_foreground_mask(aligned, target)
    if np.mean(fg_mask) < 0.001:
        return aligned
    degraded = aligned.copy()
    # 1. Gaussian blur
    applied_sigma = max(blur_sigma * cfg['sigma_mult'], cfg['sigma_min'])
    if applied_sigma > 0:
        degraded = apply_gaussian_blur(degraded, applied_sigma)
    # 2. Motion blur
    applied_motion = max(motion_kernel, cfg['motion_min'])
    if applied_motion > 1:
        degraded = apply_motion_blur(degraded, applied_motion, motion_angle)
    # 3. H.264 CRF compression
    applied_crf = min(crf + cfg['crf_boost'], 51)
    try:
        degraded = apply_h264_compression(degraded, applied_crf)
    except (subprocess.CalledProcessError, FileNotFoundError):
        # Fallback to JPEG when ffmpeg is unavailable or fails
        jpeg_q = max(5, 95 - applied_crf * 2)
        _, encoded = cv2.imencode('.jpg', degraded, [int(cv2.IMWRITE_JPEG_QUALITY), jpeg_q])
        degraded = cv2.imdecode(encoded, cv2.IMREAD_COLOR)
    # Blend into foreground only
    mask_3ch = fg_mask[:, :, np.newaxis]
    result = (degraded.astype(np.float32) * mask_3ch +
              aligned.astype(np.float32) * (1.0 - mask_3ch))
    return np.clip(result, 0, 255).astype(np.uint8)

# ============== Paste-back unedited regions ==============

def detect_unedited_mask(aligned, target, threshold=45, min_edit_area=2000,
                         safety_radius=8, blur_size=31):
    """Soft mask of regions that appear unchanged between the aligned result and the target."""
    diff = cv2.absdiff(aligned, target)
    diff_gray = cv2.cvtColor(diff, cv2.COLOR_BGR2GRAY)
    _, edited_binary = cv2.threshold(diff_gray, threshold, 255, cv2.THRESH_BINARY)
    grow_kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (7, 7))
    edited_binary = cv2.dilate(edited_binary, grow_kernel, iterations=1)
    close_kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (31, 31))
    edited_binary = cv2.morphologyEx(edited_binary, cv2.MORPH_CLOSE, close_kernel, iterations=1)
    num_labels, labels, stats, _ = cv2.connectedComponentsWithStats(edited_binary, connectivity=8)
    cleaned = np.zeros_like(edited_binary)
    for i in range(1, num_labels):
        if stats[i, cv2.CC_STAT_AREA] >= min_edit_area:
            cleaned[labels == i] = 255
    edited_binary = cleaned
    if safety_radius > 0:
        safety_kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE,
                                                  (safety_radius * 2 + 1, safety_radius * 2 + 1))
        edited_binary = cv2.dilate(edited_binary, safety_kernel, iterations=1)
    unedited_binary = 255 - edited_binary
    open_kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (41, 41))
    unedited_binary = cv2.morphologyEx(unedited_binary, cv2.MORPH_OPEN, open_kernel, iterations=2)
    blur_size = blur_size | 1  # force an odd blur kernel
    soft_mask = cv2.GaussianBlur(unedited_binary.astype(np.float32) / 255.0,
                                 (blur_size, blur_size), 0)
    return soft_mask

def paste_unedited_regions(aligned, target, mask):
    """Alpha-blend the target back over the aligned image where the mask marks 'unedited'."""
    mask_3ch = mask[:, :, np.newaxis]
    result = target.astype(np.float32) * mask_3ch + aligned.astype(np.float32) * (1.0 - mask_3ch)
    return np.clip(result, 0, 255).astype(np.uint8)

# ============== Alignment Pipeline ==============

def align_image(source_img, target_img, pp_level=2, paste_back=True):
    """Full pipeline: resize, SIFT+homography warp, color transfer, paste-back, post-process."""
    target_h, target_w = target_img.shape[:2]
    target_size = (target_w, target_h)
    source_resized = cv2.resize(source_img, target_size, interpolation=cv2.INTER_LANCZOS4)
    naive_resized = source_resized.copy()
    kp_src, desc_src = extract_features(source_resized)
    kp_tgt, desc_tgt = extract_features(target_img)
    matches = match_features(desc_src, desc_tgt)
    color_mask = None
    if len(matches) >= 4:
        H, mask = compute_homography(kp_src, kp_tgt, matches)
        if H is not None and mask is not None:
            inlier_mask = mask.ravel()
            aligned = cv2.warpPerspective(source_resized, H, target_size,
                                          flags=cv2.INTER_LANCZOS4,
                                          borderMode=cv2.BORDER_REPLICATE)
            color_mask = create_inlier_mask(kp_tgt, matches, inlier_mask,
                                            target_img.shape, radius=50)
        else:
            aligned = source_resized
    else:
        aligned = source_resized
    result = fast_color_transfer(aligned, target_img, mask=color_mask)
    # Optionally paste back unedited regions from target
    pre_paste = result.copy()
    unedited_mask = None
    if paste_back:
        unedited_mask = detect_unedited_mask(result, target_img)
        result = paste_unedited_regions(result, target_img, unedited_mask)
    # Post-processing (only affects edited regions, then re-paste)
    pp_result = None
    if pp_level > 0:
        pp_result = postprocess_foreground(result, target_img, level=pp_level)
        if paste_back and unedited_mask is not None:
            pp_result = paste_unedited_regions(pp_result, target_img, unedited_mask)
    final = pp_result if pp_result is not None else result
    return final, naive_resized, result, pre_paste, unedited_mask, pp_result
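
# Illustrative sketch: running the pipeline on files from disk. The default
# paths are hypothetical; align_image itself only sees decoded BGR arrays.
def _align_files_demo(source_path="source.png", target_path="target.png",
                      out_path="aligned.png"):
    src = cv2.imread(source_path)
    tgt = cv2.imread(target_path)
    if src is None or tgt is None:
        raise FileNotFoundError("could not read one of the input images")
    final, *_ = align_image(src, tgt, pp_level=2, paste_back=True)
    cv2.imwrite(out_path, final)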

def compute_diff_image(img1, img2, amplify=3.0):
    """Amplified grayscale absolute difference for visualization."""
    gray1 = cv2.cvtColor(img1, cv2.COLOR_BGR2GRAY).astype(np.float32)
    gray2 = cv2.cvtColor(img2, cv2.COLOR_BGR2GRAY).astype(np.float32)
    abs_diff = np.abs(gray1 - gray2)
    diff_vis = np.clip(abs_diff * amplify, 0, 255).astype(np.uint8)
    return cv2.cvtColor(diff_vis, cv2.COLOR_GRAY2BGR)

def create_visualization_panel(naive_resized, aligned, target, pre_paste=None,
                               unedited_mask=None, postprocessed=None):
    """Two-row panel: labeled result images on top, amplified diffs below."""
    h, w = target.shape[:2]
    label_height = 40
    font = cv2.FONT_HERSHEY_SIMPLEX
    font_scale = 0.7
    font_thickness = 2
    font_color = (255, 255, 255)
    bg_color = (40, 40, 40)

    def add_label(img, text):
        label_bar = np.full((label_height, img.shape[1], 3), bg_color, dtype=np.uint8)
        text_size = cv2.getTextSize(text, font, font_scale, font_thickness)[0]
        text_x = (img.shape[1] - text_size[0]) // 2
        text_y = (label_height + text_size[1]) // 2
        cv2.putText(label_bar, text, (text_x, text_y), font, font_scale, font_color, font_thickness)
        return np.vstack([label_bar, img])

    empty = np.full((h + label_height, w, 3), bg_color, dtype=np.uint8)
    mask_vis_bgr = None
    if unedited_mask is not None:
        mask_vis = (unedited_mask * 255).astype(np.uint8)
        mask_vis_bgr = cv2.cvtColor(mask_vis, cv2.COLOR_GRAY2BGR)
    diff_naive = compute_diff_image(naive_resized, target)
    diff_aligned = compute_diff_image(aligned, target)
    diff_pre_paste = compute_diff_image(pre_paste, target) if pre_paste is not None else None
    if postprocessed is not None:
        diff_pp = compute_diff_image(postprocessed, target)
        row1_items = [
            add_label(naive_resized, "Naive Resize"),
            add_label(aligned, "Aligned+Pasted"),
            add_label(postprocessed, "Post-processed"),
            add_label(target, "Target Reference"),
        ]
        row2_items = [
            add_label(diff_naive, "Diff: Naive vs Target"),
            add_label(diff_pre_paste, "Diff: Pre-paste vs Target") if diff_pre_paste is not None else empty,
            add_label(diff_aligned, "Diff: Pasted vs Target"),
            add_label(diff_pp, "Diff: Post-proc vs Target"),
        ]
    else:
        row1_items = [
            add_label(naive_resized, "Naive Resize"),
            add_label(aligned, "Aligned+Pasted"),
            add_label(target, "Target Reference"),
        ]
        row2_items = [
            add_label(diff_naive, "Diff: Naive vs Target"),
            add_label(diff_pre_paste, "Diff: Pre-paste vs Target") if diff_pre_paste is not None else empty,
            add_label(diff_aligned, "Diff: Pasted vs Target"),
        ]
    if mask_vis_bgr is not None:
        row1_items.append(add_label(mask_vis_bgr, "Unedited Mask"))
        row2_items.append(empty)
    row1 = np.hstack(row1_items)
    row2 = np.hstack(row2_items)
    return np.vstack([row1, row2])

# ============== FastAPI App ==============
app = FastAPI(title="Image Aligner API")
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

def decode_image(data: bytes) -> np.ndarray:
    img_array = np.frombuffer(data, dtype=np.uint8)
    img = cv2.imdecode(img_array, cv2.IMREAD_COLOR)
    return img

def encode_image_png(img: np.ndarray) -> bytes:
    _, buffer = cv2.imencode('.png', img)
    return buffer.tobytes()

@app.post("/api/align")
async def align_api(
    source: UploadFile = File(..., description="Source image to align"),
    target: UploadFile = File(..., description="Target reference image"),
    pp: int = Form(2, description="Post-processing level 0-3 (0=none, default=2)"),
    paste_back: bool = Form(True, description="Paste back unedited regions from target (default=true)")
):
    """
    Align source image to target image.
    Returns the aligned image as PNG.
    """
    try:
        pp_level = max(0, min(3, pp))
        source_data = await source.read()
        target_data = await target.read()
        source_img = decode_image(source_data)
        target_img = decode_image(target_data)
        if source_img is None or target_img is None:
            raise HTTPException(status_code=400, detail="Failed to decode images")
        final, *_ = align_image(source_img, target_img, pp_level=pp_level, paste_back=paste_back)
        png_bytes = encode_image_png(final)
        return Response(content=png_bytes, media_type="image/png")
    except HTTPException:
        raise  # preserve the 400 above instead of wrapping it in a 500
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))

@app.post("/api/align/base64")
async def align_base64_api(
    source: UploadFile = File(...),
    target: UploadFile = File(...),
    pp: int = Form(2, description="Post-processing level 0-3 (0=none, default=2)"),
    paste_back: bool = Form(True, description="Paste back unedited regions from target (default=true)")
):
    """
    Align source image to target image.
    Returns the aligned image as base64-encoded PNG.
    """
    try:
        pp_level = max(0, min(3, pp))
        source_data = await source.read()
        target_data = await target.read()
        source_img = decode_image(source_data)
        target_img = decode_image(target_data)
        if source_img is None or target_img is None:
            raise HTTPException(status_code=400, detail="Failed to decode images")
        final, *_ = align_image(source_img, target_img, pp_level=pp_level, paste_back=paste_back)
        png_bytes = encode_image_png(final)
        b64 = base64.b64encode(png_bytes).decode('utf-8')
        return {"image": f"data:image/png;base64,{b64}"}
    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))

@app.post("/api/align/viz")
async def align_viz_api(
    source: UploadFile = File(...),
    target: UploadFile = File(...),
    pp: int = Form(2, description="Post-processing level 0-3 (0=none, default=2)"),
    paste_back: bool = Form(True, description="Paste back unedited regions from target (default=true)")
):
    """
    Align source image to target and return visualization panel + final result.
    """
    try:
        pp_level = max(0, min(3, pp))
        source_data = await source.read()
        target_data = await target.read()
        source_img = decode_image(source_data)
        target_img = decode_image(target_data)
        if source_img is None or target_img is None:
            raise HTTPException(status_code=400, detail="Failed to decode images")
        final, naive_resized, pasted, pre_paste, unedited_mask, pp_result = \
            align_image(source_img, target_img, pp_level=pp_level, paste_back=paste_back)
        panel = create_visualization_panel(
            naive_resized, pasted, target_img,
            pre_paste=pre_paste,
            unedited_mask=unedited_mask,
            postprocessed=pp_result
        )
        panel_bytes = encode_image_png(panel)
        final_bytes = encode_image_png(final)
        panel_b64 = base64.b64encode(panel_bytes).decode('utf-8')
        final_b64 = base64.b64encode(final_bytes).decode('utf-8')
        return {
            "panel": f"data:image/png;base64,{panel_b64}",
            "image": f"data:image/png;base64,{final_b64}",
        }
    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))

HTML_CONTENT = """
<!DOCTYPE html>
<html lang="en">
<head>
    <meta charset="UTF-8">
    <meta name="viewport" content="width=device-width, initial-scale=1.0">
    <title>Image Aligner</title>
    <style>
        * { margin: 0; padding: 0; box-sizing: border-box; }
        body {
            font-family: 'Segoe UI', system-ui, sans-serif;
            background: linear-gradient(135deg, #1a1a2e 0%, #16213e 50%, #0f3460 100%);
            min-height: 100vh;
            color: #e8e8e8;
            padding: 2rem;
        }
        .dedication {
            text-align: center;
            padding: 2rem;
            background: linear-gradient(135deg, rgba(255, 121, 198, 0.15), rgba(139, 233, 253, 0.15));
            border-radius: 12px;
            margin-bottom: 2rem;
        }
        .dedication h2 { font-size: 1.2rem; font-weight: 300; margin-bottom: 0.5rem; }
        .dedication .names {
            font-size: 1.5rem;
            font-weight: 700;
            background: linear-gradient(90deg, #ff79c6, #ffb86c, #8be9fd, #50fa7b);
            -webkit-background-clip: text;
            -webkit-text-fill-color: transparent;
        }
        .dedication .team { font-size: 1.1rem; color: #8be9fd; margin-top: 0.5rem; }
        .container { max-width: 1200px; margin: 0 auto; }
        h1 { text-align: center; margin-bottom: 0.5rem; font-weight: 300; font-size: 2.5rem; }
        .subtitle { text-align: center; color: #888; margin-bottom: 2rem; }
        .upload-grid { display: grid; grid-template-columns: 1fr 1fr; gap: 2rem; margin-bottom: 2rem; }
        .upload-box {
            background: rgba(255,255,255,0.03);
            border: 2px dashed rgba(255,255,255,0.2);
            border-radius: 12px;
            padding: 2rem;
            text-align: center;
            cursor: pointer;
            transition: all 0.3s;
            min-height: 250px;
            display: flex;
            flex-direction: column;
            align-items: center;
            justify-content: center;
        }
        .upload-box:hover { border-color: rgba(255,255,255,0.4); background: rgba(255,255,255,0.05); }
        .upload-box.has-image { padding: 1rem; }
        .upload-box img { max-width: 100%; max-height: 200px; border-radius: 8px; }
        .upload-box input { display: none; }
        .upload-box h3 { margin-bottom: 0.5rem; }
        .upload-box.source h3 { color: #8be9fd; }
        .upload-box.target h3 { color: #ffb86c; }
        .options-row {
            display: flex;
            align-items: center;
            justify-content: center;
            gap: 1.5rem;
            margin-bottom: 2rem;
            flex-wrap: wrap;
        }
        .options-row label {
            font-size: 0.95rem;
            color: #aaa;
        }
        .pp-select {
            background: rgba(255,255,255,0.08);
            color: #e8e8e8;
            border: 1px solid rgba(255,255,255,0.2);
            border-radius: 6px;
            padding: 0.5rem 1rem;
            font-size: 0.95rem;
            cursor: pointer;
        }
        .pp-select option { background: #1a1a2e; color: #e8e8e8; }
        .btn {
            display: block;
            width: 100%;
            max-width: 300px;
            margin: 0 auto 2rem;
            padding: 1rem 2rem;
            font-size: 1.1rem;
            font-weight: 600;
            border: none;
            border-radius: 8px;
            background: linear-gradient(135deg, #50fa7b, #00d9ff);
            color: #1a1a2e;
            cursor: pointer;
            transition: all 0.3s;
        }
        .btn:hover:not(:disabled) { transform: translateY(-2px); box-shadow: 0 10px 30px rgba(80,250,123,0.3); }
        .btn:disabled { opacity: 0.5; cursor: not-allowed; }
        .result { text-align: center; display: none; }
        .result.show { display: block; }
        .result img { max-width: 100%; border-radius: 8px; margin: 1rem 0; }
        .result a {
            display: inline-block;
            padding: 0.8rem 2rem;
            background: rgba(255,255,255,0.1);
            color: #fff;
            text-decoration: none;
            border-radius: 8px;
            margin-top: 1rem;
        }
        .loading { display: none; text-align: center; padding: 2rem; }
        .loading.show { display: block; }
        .spinner {
            width: 50px; height: 50px;
            border: 3px solid rgba(255,255,255,0.1);
            border-top-color: #50fa7b;
            border-radius: 50%;
            animation: spin 1s linear infinite;
            margin: 0 auto 1rem;
        }
        @keyframes spin { to { transform: rotate(360deg); } }
        .api-docs {
            background: rgba(255,255,255,0.03);
            border-radius: 12px;
            padding: 2rem;
            margin-top: 3rem;
        }
        .api-docs h2 { margin-bottom: 1rem; color: #50fa7b; }
        .api-docs pre {
            background: rgba(0,0,0,0.3);
            padding: 1rem;
            border-radius: 8px;
            overflow-x: auto;
            font-size: 0.9rem;
        }
        .api-docs code { color: #8be9fd; }
        @media (max-width: 768px) { .upload-grid { grid-template-columns: 1fr; } }
    </style>
</head>
<body>
    <div class="container">
        <div class="dedication">
            <h2>Dedicated with ♥ love and devotion to</h2>
            <div class="names">Alon Y., Daniel B., Denis Z., Tal S., Adi B.</div>
            <div class="team">and the rest of the Animation Taskforce 2026</div>
        </div>
        <h1>🎯 Image Aligner</h1>
        <p class="subtitle">Geometric alignment with background-aware color matching</p>
        <div class="upload-grid">
            <div class="upload-box source" onclick="document.getElementById('sourceInput').click()">
                <input type="file" id="sourceInput" accept="image/*">
                <h3>📷 Source Image</h3>
                <p>Click to upload</p>
            </div>
            <div class="upload-box target" onclick="document.getElementById('targetInput').click()">
                <input type="file" id="targetInput" accept="image/*">
                <h3>🎯 Target Reference</h3>
                <p>Click to upload</p>
            </div>
        </div>
        <div class="options-row">
            <label for="ppLevel">Post-processing:</label>
            <select id="ppLevel" class="pp-select">
                <option value="0">0 - None</option>
                <option value="1">1 - Weak</option>
                <option value="2" selected>2 - Medium (default)</option>
                <option value="3">3 - Strong</option>
            </select>
            <label style="display:flex;align-items:center;gap:0.4rem;cursor:pointer;">
                <input type="checkbox" id="pasteBack" checked style="width:18px;height:18px;cursor:pointer;">
                Paste back unedited regions
            </label>
        </div>
        <button class="btn" id="alignBtn" disabled onclick="alignImages()">✨ Align Images</button>
        <div class="loading" id="loading">
            <div class="spinner"></div>
            <p>Aligning images...</p>
        </div>
        <div class="result" id="result">
            <h2>✨ Visualization</h2>
            <img id="panelImg" src="" style="max-width:100%">
            <br>
            <a id="downloadLink" download="aligned.png">Download Aligned Image</a>
        </div>
        <div class="api-docs">
            <h2>📡 API Usage</h2>
            <p>POST to <code>/api/align</code> with multipart form data:</p>
            <pre><code>// JavaScript (fetch)
const formData = new FormData();
formData.append('source', sourceFile);
formData.append('target', targetFile);
formData.append('pp', '2'); // 0=none, 1=weak, 2=medium, 3=strong
const response = await fetch('/api/align', {
    method: 'POST',
    body: formData
});
const blob = await response.blob();
const url = URL.createObjectURL(blob);

// Or use /api/align/base64 for a base64 response:
const b64Response = await fetch('/api/align/base64', {
    method: 'POST',
    body: formData
});
const data = await b64Response.json();
console.log(data.image); // data:image/png;base64,...</code></pre>
        </div>
    </div>
    <script>
        let sourceFile = null;
        let targetFile = null;
        document.getElementById('sourceInput').onchange = (e) => {
            sourceFile = e.target.files[0];
            showPreview('source', sourceFile);
            updateButton();
        };
        document.getElementById('targetInput').onchange = (e) => {
            targetFile = e.target.files[0];
            showPreview('target', targetFile);
            updateButton();
        };
        function showPreview(type, file) {
            const box = document.querySelector(`.upload-box.${type}`);
            const reader = new FileReader();
            reader.onload = (e) => {
                box.innerHTML = `<img src="${e.target.result}">`;
                box.classList.add('has-image');
            };
            reader.readAsDataURL(file);
        }
        function updateButton() {
            document.getElementById('alignBtn').disabled = !(sourceFile && targetFile);
        }
        async function alignImages() {
            const loading = document.getElementById('loading');
            const result = document.getElementById('result');
            loading.classList.add('show');
            result.classList.remove('show');
            try {
                const formData = new FormData();
                formData.append('source', sourceFile);
                formData.append('target', targetFile);
                formData.append('pp', document.getElementById('ppLevel').value);
                formData.append('paste_back', document.getElementById('pasteBack').checked ? 'true' : 'false');
                const response = await fetch('/api/align/viz', {
                    method: 'POST',
                    body: formData
                });
                if (!response.ok) throw new Error('Alignment failed');
                const data = await response.json();
                document.getElementById('panelImg').src = data.panel;
                document.getElementById('downloadLink').href = data.image;
                result.classList.add('show');
            } catch (err) {
                alert('Error: ' + err.message);
            } finally {
                loading.classList.remove('show');
            }
        }
    </script>
</body>
</html>
"""

@app.get("/", response_class=HTMLResponse)
async def root():
    return HTML_CONTENT

if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=7860)
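
# Illustrative client sketch (assumes the `requests` package and a server on
# localhost:7860; both are assumptions, not requirements of this module):
#
#   import requests
#   with open("source.png", "rb") as s, open("target.png", "rb") as t:
#       resp = requests.post(
#           "http://localhost:7860/api/align",
#           files={"source": s, "target": t},
#           data={"pp": "2", "paste_back": "true"},
#       )
#   resp.raise_for_status()
#   with open("aligned.png", "wb") as f:
#       f.write(resp.content)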