Spaces:
Running on Zero
Running on Zero
File size: 7,011 Bytes
e340a84 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 | import torch
from typing import Dict, Any, List
from longstream.streaming.stream_session import StreamSession
_SEQUENCE_OUTPUT_KEYS = {
"pose_enc",
"rel_pose_enc",
"world_points",
"world_points_conf",
"depth",
"depth_conf",
}
_SCALAR_OUTPUT_KEYS = {
"predicted_scale_factor",
"global_scale",
}
def _refresh_intervals(refresh: int) -> int:
refresh = int(refresh)
if refresh < 2:
raise ValueError("refresh must be >= 2")
return refresh - 1
def _model_device(model) -> torch.device:
return next(model.parameters()).device
def _move_scalar_to_cpu(value: Any) -> Any:
if isinstance(value, torch.Tensor):
return value.detach().cpu()
return value
def _append_batch_output(
stitched_tensors: Dict[str, List[torch.Tensor]],
stitched_scalars: Dict[str, Any],
output: Dict[str, Any],
actual_frames: int,
slice_start: int,
) -> None:
for key in _SEQUENCE_OUTPUT_KEYS:
value = output.get(key)
if not isinstance(value, torch.Tensor):
continue
if value.ndim < 2 or value.shape[1] != actual_frames:
continue
stitched_tensors.setdefault(key, []).append(
value[:, slice_start:].detach().cpu()
)
for key in _SCALAR_OUTPUT_KEYS:
if key in output:
stitched_scalars[key] = _move_scalar_to_cpu(output[key])
def _finalize_stitched_batches(
stitched_tensors: Dict[str, List[torch.Tensor]],
stitched_scalars: Dict[str, Any],
) -> Dict[str, Any]:
stitched_output: Dict[str, Any] = {}
for key, chunks in stitched_tensors.items():
if not chunks:
continue
stitched_output[key] = (
chunks[0] if len(chunks) == 1 else torch.cat(chunks, dim=1)
)
stitched_output.update(stitched_scalars)
return stitched_output
def run_batch_refresh(
model,
images,
is_keyframe,
keyframe_indices,
mode: str,
keyframe_stride: int,
refresh: int,
rel_pose_cfg,
):
B, S = images.shape[:2]
device = _model_device(model)
refresh_intervals = _refresh_intervals(refresh)
frames_per_batch = refresh_intervals * keyframe_stride + 1
step_frames = refresh_intervals * keyframe_stride
stitched_tensors: Dict[str, List[torch.Tensor]] = {}
stitched_scalars: Dict[str, Any] = {}
num_batches = (S + step_frames - 1) // step_frames
for batch_idx in range(num_batches):
start_frame = batch_idx * step_frames
end_frame = min(start_frame + frames_per_batch, S)
batch_images = images[:, start_frame:end_frame].to(device, non_blocking=True)
batch_is_keyframe = (
is_keyframe[:, start_frame:end_frame].clone()
if is_keyframe is not None
else None
)
batch_keyframe_indices = (
keyframe_indices[:, start_frame:end_frame].clone()
if keyframe_indices is not None
else None
)
if batch_idx > 0 and batch_is_keyframe is not None:
batch_is_keyframe[:, 0] = True
if batch_keyframe_indices is not None:
batch_keyframe_indices[:, 0] = start_frame
if batch_keyframe_indices is not None:
batch_keyframe_indices = batch_keyframe_indices - start_frame
batch_keyframe_indices = torch.clamp(
batch_keyframe_indices, 0, end_frame - start_frame - 1
)
batch_rel_pose_inputs = None
if rel_pose_cfg is not None and batch_is_keyframe is not None:
batch_is_keyframe = batch_is_keyframe.to(device, non_blocking=True)
if batch_keyframe_indices is not None:
batch_keyframe_indices = batch_keyframe_indices.to(
device, non_blocking=True
)
batch_rel_pose_inputs = {
"is_keyframe": batch_is_keyframe,
"keyframe_indices": batch_keyframe_indices,
"num_iterations": rel_pose_cfg.get("num_iterations", 4),
}
elif batch_is_keyframe is not None:
batch_is_keyframe = batch_is_keyframe.to(device, non_blocking=True)
batch_output = model(
images=batch_images,
mode=mode,
rel_pose_inputs=batch_rel_pose_inputs,
is_keyframe=batch_is_keyframe,
)
_append_batch_output(
stitched_tensors,
stitched_scalars,
batch_output,
actual_frames=end_frame - start_frame,
slice_start=0 if batch_idx == 0 else 1,
)
del batch_output
del batch_images
del batch_is_keyframe
del batch_keyframe_indices
return _finalize_stitched_batches(stitched_tensors, stitched_scalars)
def run_streaming_refresh(
model,
images,
is_keyframe,
keyframe_indices,
mode: str,
window_size: int,
refresh: int,
rel_pose_cfg,
):
B, S = images.shape[:2]
device = _model_device(model)
refresh_intervals = _refresh_intervals(refresh)
session = StreamSession(model, mode=mode, window_size=window_size)
keyframe_count = 0
segment_start = 0
for s in range(S):
frame_images = images[:, s : s + 1].to(device, non_blocking=True)
is_keyframe_s = (
is_keyframe[:, s : s + 1].to(device, non_blocking=True)
if is_keyframe is not None
else None
)
if keyframe_indices is not None:
keyframe_indices_s = keyframe_indices[:, s : s + 1].clone() - segment_start
keyframe_indices_s = torch.clamp(keyframe_indices_s, min=0)
keyframe_indices_s = keyframe_indices_s.to(device, non_blocking=True)
else:
keyframe_indices_s = None
session.forward_stream(
frame_images,
is_keyframe=is_keyframe_s,
keyframe_indices=keyframe_indices_s,
record=True,
)
if is_keyframe_s is None or not bool(is_keyframe_s.item()) or s <= 0:
del frame_images
if is_keyframe_s is not None:
del is_keyframe_s
if keyframe_indices_s is not None:
del keyframe_indices_s
continue
keyframe_count += 1
if keyframe_count % refresh_intervals == 0:
session.clear_cache_only()
segment_start = s
if keyframe_indices_s is not None:
keyframe_indices_self = torch.zeros_like(keyframe_indices_s)
else:
keyframe_indices_self = None
session.forward_stream(
frame_images,
is_keyframe=is_keyframe_s,
keyframe_indices=keyframe_indices_self,
record=False,
)
del frame_images
if is_keyframe_s is not None:
del is_keyframe_s
if keyframe_indices_s is not None:
del keyframe_indices_s
return session.get_all_predictions()
|