Anirudh Balaraman commited on
Commit
7efd22b
·
1 Parent(s): 6a5314a

update patch visualisation

Browse files
Files changed (3) hide show
  1. src/utils.py +69 -39
  2. temp.ipynb +0 -0
  3. visualisation.ipynb +0 -0
src/utils.py CHANGED
@@ -4,7 +4,8 @@ import os
4
  import sys
5
  from pathlib import Path
6
  from typing import Any, Union
7
-
 
8
  import cv2
9
  import numpy as np
10
  import torch
@@ -173,52 +174,81 @@ def get_parent_image(temp_data_list, args: argparse.Namespace) -> np.ndarray:
173
  return dataset_image[0]["image"][0].numpy()
174
 
175
 
176
- """
177
- def visualise_patches():
178
- sample = np.array([i.transpose(1,2,0) for i in patches_top_5])
179
- rows = len(patches_top_5)
180
- img = sample[0]
181
- coords = []
182
- rows, h, w, slices = sample.shape
183
 
184
- fig, axes = plt.subplots(nrows=rows, ncols=slices, figsize=(slices * 3, rows * 3))
185
-
186
- for i in range(rows):
187
- for j in range(slices):
188
- ax = axes[i, j]
189
-
190
- if j == 0:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
191
 
192
- for k in range(parent_image.shape[2]):
193
- img_temp = parent_image[:, :, k]
194
- H, W = img_temp.shape
195
- h, w = sample[i, :, :, j].shape
196
- a,b = 0, 0 # Initialize a and b
197
- bool1 = False
198
- for l in range(H - h + 1):
199
- for m in range(W - w + 1):
200
- if np.array_equal(img_temp[l:l+h, m:m+w], sample[i, :, :, j]):
201
- a,b = l, m # top-left corner
202
- coords.append((a,b,k))
203
- bool1 = True
204
- break
205
- if bool1:
206
- break
207
 
208
- if bool1:
209
- break
210
 
 
211
 
 
 
 
 
 
 
 
 
 
 
 
 
212
 
 
 
 
 
 
 
 
 
 
 
 
213
 
214
- ax.imshow(parent_image[:, :, k+j], cmap='gray')
215
- rect = patches.Rectangle((b, a), args.tile_size, args.tile_size,
216
- linewidth=2, edgecolor='red', facecolor='none')
217
- ax.add_patch(rect)
218
  ax.axis('off')
219
 
220
-
 
 
 
 
 
 
 
 
 
 
221
  plt.tight_layout()
222
  plt.show()
223
- a=1
224
- """
 
4
  import sys
5
  from pathlib import Path
6
  from typing import Any, Union
7
+ import matplotlib.pyplot as plt
8
+ import matplotlib.patches as patches
9
  import cv2
10
  import numpy as np
11
  import torch
 
174
  return dataset_image[0]["image"][0].numpy()
175
 
176
 
 
 
 
 
 
 
 
177
 
178
+ def visualise_patches(coords, image, tile_size = 64, depth=3):
179
+ """
180
+ Visualize 3D image patches with their locations marked by bounding rectangles.
181
+ This function creates a grid of subplot visualizations where each row represents
182
+ a patch and each column represents a slice along the z-axis. Each patch location
183
+ is highlighted with a red rectangle on the corresponding image slice.
184
+ Args:
185
+ coords (list): List of patch coordinates, where each coordinate is a tuple/list
186
+ of (y, x, z) representing the top-left corner position of the patch.
187
+ image (ndarray): 3D image array of shape (height, width, slices) containing the
188
+ image data to visualize.
189
+ tile_size (int, optional): Size of the square patch in pixels. Defaults to 64.
190
+ depth (int, optional): Number of consecutive z-slices to display for each patch.
191
+ Defaults to 3.
192
+ Returns:
193
+ None: Displays the visualization using plt.show(). The slice id is displayed on th etop left corner of the image.
194
+ Raises:
195
+ None
196
+ Example:
197
+ >>> coords = [(10, 20, 5), (50, 60, 10)]
198
+ >>> image = np.random.rand(256, 256, 50)
199
+ >>> visualise_patches(coords, image, tile_size=64, depth=3)
200
+ """
201
 
202
+ rows, _, _, slices = (len(coords), tile_size, tile_size, depth)
203
+ fig, axes = plt.subplots(
204
+ nrows=rows,
205
+ ncols=slices,
206
+ figsize=(slices * 3, rows * 3),
207
+ squeeze=False
208
+ )
 
 
 
 
 
 
 
 
209
 
210
+ for i, x in enumerate(coords):
211
+ for j in range(slices):
212
 
213
+ ax = axes[i, j]
214
 
215
+ slice_id = x[2] + j
216
+ ax.imshow(image[:, :, slice_id], cmap='gray')
217
+
218
+ rect = patches.Rectangle(
219
+ (x[1], x[0]),
220
+ tile_size,
221
+ tile_size,
222
+ linewidth=2,
223
+ edgecolor='red',
224
+ facecolor='none'
225
+ )
226
+ ax.add_patch(rect)
227
 
228
+ # ---- slice ID text (every image) ----
229
+ ax.text(
230
+ 0.02, 0.98,
231
+ f"z={slice_id}",
232
+ transform=ax.transAxes,
233
+ fontsize=10,
234
+ color='white',
235
+ va='top',
236
+ ha='left',
237
+ bbox=dict(facecolor='black', alpha=0.4, pad=2)
238
+ )
239
 
 
 
 
 
240
  ax.axis('off')
241
 
242
+ # Row label
243
+ axes[i, 0].text(
244
+ -0.08, 0.5,
245
+ f"Patch {i+1}",
246
+ transform=axes[i, 0].transAxes,
247
+ fontsize=12,
248
+ va='center',
249
+ ha='right'
250
+ )
251
+
252
+ plt.subplots_adjust(left=0.06)
253
  plt.tight_layout()
254
  plt.show()
 
 
temp.ipynb DELETED
The diff for this file is too large to render. See raw diff
 
visualisation.ipynb ADDED
The diff for this file is too large to render. See raw diff