deepfake-server / src /data /dataset.py
DevQueen's picture
Sync from GitHub via hub-sync
1dc2504 verified
Raw
History Blame Contribute Delete
1.52 kB
from __future__ import annotations
from pathlib import Path
from typing import Dict, List
import numpy as np
import pandas as pd
import torch
from torch.utils.data import Dataset
class EyeSequenceDataset(Dataset):
    """Dataset of eye-region frame sequences stored as ``.npz`` archives.

    The metadata CSV must contain ``split`` and ``label`` columns. Each row
    must also give either:

    * ``npz_path`` -- a single archive that forms one sample, or
    * ``sequence_dir`` -- the legacy layout from ``extract_eye_sequences.py``,
      where every ``*.npz`` file in the directory (in sorted order) becomes
      one sample carrying that row's label.

    Each archive must contain a ``frames`` array (indexed as T,H,W,C, see the
    transpose in ``__getitem__``) and an ``ear`` array, or ``blink`` as a
    fallback key.

    Args:
        metadata_csv: Path to the metadata CSV file.
        split: Only rows whose ``split`` column equals this value are used.

    Raises:
        ValueError: A selected row has neither ``npz_path`` nor
            ``sequence_dir`` set.
    """

    def __init__(self, metadata_csv: str, split: str) -> None:
        # Each entry: {"path": <str path to .npz>, "label": <int label>}.
        self.samples: List[Dict[str, object]] = []
        df = pd.read_csv(metadata_csv)
        df = df[df["split"] == split]
        for row in df.to_dict(orient="records"):
            label = int(row["label"])
            # to_dict() yields every column for every row, so in a CSV that
            # mixes layouts the legacy rows carry NaN here. Test the value,
            # not just the presence of the key.
            npz_path = row.get("npz_path")
            if pd.notna(npz_path):
                self.samples.append({"path": str(npz_path), "label": label})
                continue
            # Legacy layout from extract_eye_sequences.py
            seq_dir_value = row.get("sequence_dir")
            if not pd.notna(seq_dir_value):
                raise ValueError(
                    f"Row has neither npz_path nor sequence_dir: {row}"
                )
            seq_dir = Path(seq_dir_value)
            for npz in sorted(seq_dir.glob("*.npz")):
                self.samples.append({"path": str(npz), "label": label})

    def __len__(self) -> int:
        """Return the number of samples (archives) in this split."""
        return len(self.samples)

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        """Load one archive and return its tensors.

        Returns:
            A dict with:

            * ``frames`` -- float32 tensor laid out T,C,H,W, divided by 255.
              NOTE(review): this assumes the stored frames are 8-bit
              (0-255); confirm with the extraction script.
            * ``ear`` and ``blink`` -- two separate float32 tensors holding
              the same per-archive sequence, read from ``ear`` (or
              ``blink`` if ``ear`` is absent).
            * ``label`` -- scalar int64 tensor.
        """
        sample = self.samples[idx]
        # NpzFile holds the zip archive open until closed. The context
        # manager releases the handle once the arrays are read into memory.
        with np.load(sample["path"]) as obj:
            frames = obj["frames"].astype(np.float32) / 255.0
            ear_key = "ear" if "ear" in obj.files else "blink"
            ear = obj[ear_key].astype(np.float32)
        # T,H,W,C -> T,C,H,W
        frames = np.transpose(frames, (0, 3, 1, 2))
        return {
            "frames": torch.tensor(frames),
            "ear": torch.tensor(ear),
            "blink": torch.tensor(ear),
            "label": torch.tensor(sample["label"], dtype=torch.long),
        }