| import os |
| import math |
| from skimage.metrics import peak_signal_noise_ratio as psnr |
| from skimage.metrics import mean_squared_error as mse |
| import lpips |
| from PIL import Image |
| import numpy as np |
| import torchvision.transforms as transforms |
| from eval.distance_transform_v0 import ChamferDistance2dMetric |
| import torch |
| from pytorch_msssim import ssim_matlab as ssim |
| import glob |
| import re |
|
|
| |
# Pick the compute device: the GPU when PyTorch can see one, else the CPU.
# NOTE(review): the evaluation code in this script never moves tensors to
# `device` (they stay on the CPU); presumably kept for imported helpers — confirm.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")
|
|
def calc_psnr(pred, gt):
    """Return the peak signal-to-noise ratio (dB) between ``pred`` and ``gt``.

    The formula is ``-10 * log10(MSE)``, i.e. PSNR with a peak value of 1, so
    inputs are presumably scaled to [0, 1] — TODO confirm at call sites.

    Args:
        pred: Prediction; any array-like supporting elementwise ``-``/``*`` and
            ``.mean()`` (e.g. a NumPy array or a torch tensor).
        gt: Ground truth with the same shape as ``pred``.

    Returns:
        The PSNR as a Python float, or ``math.inf`` when the inputs are
        identical (MSE == 0).
    """
    mse_val = float(((pred - gt) * (pred - gt)).mean())
    # math.log10(0) raises ValueError; a perfect reconstruction has infinite PSNR.
    if mse_val == 0:
        return math.inf
    return -10 * math.log10(mse_val)
|
|
| |
# LPIPS perceptual-distance model with the AlexNet backbone.
# NOTE(review): `loss_fn_alex` is never called by the metrics computed in this
# script; constructing it still loads the network — confirm it is needed.
loss_fn_alex = lpips.LPIPS(net="alex")

# PIL image -> integer tensor (PILToTensor does no rescaling; callers divide by 255).
transform = transforms.Compose([transforms.PILToTensor()])

# Chamfer-distance metric object; its `.calc(pred, gt)` is used per image pair.
cd = ChamferDistance2dMetric()

# Show tensors with four decimals, the same precision used for metric output.
torch.set_printoptions(precision=4)
|
|
| |
# Per-pair metric values, appended by the evaluation loop and averaged at the end.
psnrlisto, ssimlisto, ielisto, cdlisto = [], [], [], []

# Folder containing interpolated frames io<N>.png and ground truths gt<N>.png.
eval_dir = "eval_data"
|
|
| |
# Interpolated-frame filenames: "io" followed by the numeric pair index.
_IO_NAME_RE = re.compile(r"io(\d+)\.png")


def _compute_pair_metrics(pred_np, gt_np, pred_t, gt_t):
    """Compute PSNR, IE, CD and SSIM for one prediction / ground-truth pair.

    Args:
        pred_np, gt_np: RGB arrays from ``np.asarray`` on the resized PIL images
            (used by PSNR and IE).
        pred_t, gt_t: float tensors of shape (1, C, H, W) scaled to [0, 1]
            (used by Chamfer Distance and SSIM).

    Returns:
        ``(psnr, ie, cd, ssim)`` as Python floats, or ``None`` if any metric
        raised. Every metric is still attempted so that all failures are printed.
    """
    specs = (
        ("PSNR", "PSNR", lambda: psnr(pred_np, gt_np)),
        # IE (interpolation error) = root of the mean squared pixel error.
        ("IE", "IE", lambda: math.sqrt(mse(pred_np, gt_np))),
        ("CD", "Chamfer Distance", lambda: cd.calc(pred_t, gt_t)),
        ("SSIM", "SSIM", lambda: ssim(pred_t, gt_t)),
    )
    values = []
    failed = False
    for label, error_name, compute in specs:
        try:
            # float() normalizes 0-d tensors (presumably what cd.calc / ssim
            # return — TODO confirm) so later averaging sees plain numbers.
            value = float(compute())
        except Exception as e:
            print(f"Error calculating {error_name}: {e}")
            failed = True
            continue
        values.append(value)
        print(f"{label}: {value:.4f}")
    return None if failed else tuple(values)


# Evaluate each io<N>.png in eval_dir against gt<N>.png. The four metric lists
# are appended together per pair, so entry i of every list refers to the same
# pair. A pair whose metrics cannot all be computed is skipped, like a pair
# with no ground truth.
if not os.path.exists(eval_dir):
    print(f"Error: Directory '{eval_dir}' not found")
else:
    io_files = glob.glob(os.path.join(eval_dir, "io*.png"))
    print(f"Found {len(io_files)} interpolated images in {eval_dir}")

    # The glob "io*.png" is broader than the io<digits>.png scheme. Parse the
    # index first and drop non-matching names, since sorting with an int() key
    # on a failed match would crash the whole run.
    indexed_files = []
    for io_path in io_files:
        index_match = _IO_NAME_RE.fullmatch(os.path.basename(io_path))
        if not index_match:
            print(f"Skipping {io_path} - unable to extract index")
            continue
        index_str = index_match.group(1)
        indexed_files.append((int(index_str), index_str, io_path))
    # Numeric order (io2 before io10); the sort is stable on equal indices.
    indexed_files.sort(key=lambda entry: entry[0])

    for _, index, io_path in indexed_files:
        # Keep the index string as written (e.g. "007") to locate the ground truth.
        gt_path = os.path.join(eval_dir, f"gt{index}.png")
        if not os.path.exists(gt_path):
            print(f"Warning: Ground truth image '{gt_path}' not found for '{io_path}'")
            continue

        print(f"Processing pair {index}: {io_path} and {gt_path}")

        try:
            # Both images are forced to RGB at 384x192 so their shapes match.
            with Image.open(io_path) as raw_pred:
                pio = raw_pred.convert("RGB").resize((384, 192))
            pred1o = transform(pio).unsqueeze(0).float() / 255.
            predo = np.asarray(pio)

            with Image.open(gt_path) as raw_gt:
                gi = raw_gt.convert("RGB").resize((384, 192))
            gt1 = transform(gi).unsqueeze(0).float() / 255.
            gt = np.asarray(gi)
        except Exception as e:
            print(f"Error processing images: {e}")
            import traceback
            traceback.print_exc()
            continue

        pair_metrics = _compute_pair_metrics(predo, gt, pred1o, gt1)
        if pair_metrics is None:
            print(f"Skipping pair {index} - not all metrics could be computed")
            continue

        psnr_val, ie_val, cd_val, ssim_val = pair_metrics
        psnrlisto.append(psnr_val)
        ielisto.append(ie_val)
        cdlisto.append(cd_val)
        ssimlisto.append(ssim_val)
|
|
| |
def _mean_or_zero(values):
    """Return the arithmetic mean of *values*, or 0 when the list is empty."""
    if not values:
        return 0
    return np.average(values)


# Average of each metric over the pairs that contributed a value.
psnr_avgo = _mean_or_zero(psnrlisto)
ie_avgo = _mean_or_zero(ielisto)
cd_avgo = _mean_or_zero(cdlisto)
ssim_avgo = _mean_or_zero(ssimlisto)
|
|
| |
# Print every pair's metrics. "Pair N" is the 1-based position in the lists
# (processing order), not the file index. Assumes the four lists are
# index-aligned; zip stops at the shortest one.
print("\nMetrics for each processed pair:")
for pair_no, pair_values in enumerate(zip(psnrlisto, ssimlisto, cdlisto, ielisto), start=1):
    psnr_val, ssim_val, cd_val, ie_val = pair_values
    print("\n".join((
        f"Pair {pair_no}:",
        f" PSNR: {psnr_val:.4f}",
        f" SSIM: {ssim_val:.4f}",
        f" CD: {cd_val:.4f}",
        f" IE: {ie_val:.4f}",
    )))
|
|
| |
# Report the averaged metrics on stdout and save the same summary to a text
# file. The pair count shown is the number of PSNR values collected.
print("\nFinal Results (Averages):")
summary_lines = [
    f"Number of image pairs processed: {len(psnrlisto)}",
    "cdval: {:.4f}".format(cd_avgo),
    "ssim: {:.4f}".format(ssim_avgo),
    "psnr: {:.4f}".format(psnr_avgo),
    "ie: {:.4f}".format(ie_avgo),
]
for summary_line in summary_lines:
    print(summary_line)

# The file gets a title line followed by the identical summary lines.
results_file = "eval_results.txt"
with open(results_file, "w") as f:
    f.write("Evaluation Results\n")
    f.writelines(f"{summary_line}\n" for summary_line in summary_lines)

print(f"\nResults saved to {results_file}")