added one function as helper
Browse files- data/cell_dataset.py +30 -0
data/cell_dataset.py
CHANGED
|
@@ -186,3 +186,33 @@ def load_folder_data(folder_path, substrate=None, img_size=1024, blur_heatmap=Fa
|
|
| 186 |
threshold=threshold, return_metadata=return_metadata)
|
| 187 |
val_dataset.name = os.path.basename(folder_path)
|
| 188 |
return DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 186 |
threshold=threshold, return_metadata=return_metadata)
|
| 187 |
val_dataset.name = os.path.basename(folder_path)
|
| 188 |
return DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
|
| 189 |
+
|
| 190 |
+
|
| 191 |
+
def collect_image_paths(folder_path, exts=None):
|
| 192 |
+
if exts is None:
|
| 193 |
+
exts = {".tif", ".tiff", ".jpg", ".jpeg", ".png"}
|
| 194 |
+
paths = []
|
| 195 |
+
for root, _, files in os.walk(os.path.normpath(folder_path)):
|
| 196 |
+
for f in files:
|
| 197 |
+
if os.path.splitext(f)[1].lower() in exts:
|
| 198 |
+
paths.append(os.path.join(root, f))
|
| 199 |
+
return sorted(paths)
|
| 200 |
+
|
| 201 |
+
|
| 202 |
+
class BrightfieldOnlyDataset(Dataset):
|
| 203 |
+
"""Dataset of brightfield images only (no labels), for inference."""
|
| 204 |
+
def __init__(self, folder_path, target_size=1024):
|
| 205 |
+
self.paths = collect_image_paths(folder_path)
|
| 206 |
+
self.target_size = (target_size, target_size) if isinstance(target_size, int) else target_size
|
| 207 |
+
|
| 208 |
+
def __len__(self):
|
| 209 |
+
return len(self.paths)
|
| 210 |
+
|
| 211 |
+
def __getitem__(self, i):
|
| 212 |
+
x = load_image(self.paths[i], self.target_size)
|
| 213 |
+
return torch.from_numpy(x).float().unsqueeze(0)
|
| 214 |
+
|
| 215 |
+
|
| 216 |
+
def load_brightfield_loader(folder_path, img_size=1024, batch_size=2):
|
| 217 |
+
ds = BrightfieldOnlyDataset(folder_path, target_size=img_size)
|
| 218 |
+
return DataLoader(ds, batch_size=batch_size, shuffle=False)
|