Cc
init
e340a84
import torch
from typing import Dict, Any, List
from longstream.streaming.stream_session import StreamSession
# Per-frame model outputs that are stitched across batches. Each is expected to
# be a tensor whose dim 1 indexes frames (see the shape check in
# _append_batch_output). A tuple (rather than a set) keeps iteration order --
# and therefore the key order of the stitched output dict -- independent of
# string hash randomization, and prevents accidental mutation.
_SEQUENCE_OUTPUT_KEYS = (
    "pose_enc",
    "rel_pose_enc",
    "world_points",
    "world_points_conf",
    "depth",
    "depth_conf",
)
# Whole-sequence model outputs. These are not concatenated across batches; the
# value from the most recently appended batch overwrites earlier ones.
_SCALAR_OUTPUT_KEYS = (
    "predicted_scale_factor",
    "global_scale",
)
def _refresh_intervals(refresh: int) -> int:
    """Convert a ``refresh`` setting into a number of keyframe intervals.

    ``refresh`` is coerced with ``int()`` and must be at least 2; the result is
    one less than that value.

    Raises:
        ValueError: if the coerced ``refresh`` is smaller than 2.
    """
    as_int = int(refresh)
    if as_int >= 2:
        return as_int - 1
    raise ValueError("refresh must be >= 2")
def _model_device(model) -> torch.device:
    """Return the device of the first parameter (or, failing that, buffer) of ``model``.

    Raises:
        ValueError: if ``model`` has neither parameters nor buffers, so no
            device can be inferred. (Previously a bare ``StopIteration`` leaked
            out, which turns into a ``RuntimeError`` inside generators.)
    """
    first = next(model.parameters(), None)
    # Parameter-free modules may still carry buffers that pin them to a device.
    if first is None and hasattr(model, "buffers"):
        first = next(model.buffers(), None)
    if first is None:
        raise ValueError("cannot infer device: model has no parameters or buffers")
    return first.device
def _move_scalar_to_cpu(value: Any) -> Any:
    """Detach a tensor and copy it to CPU; return any non-tensor value unchanged."""
    if not isinstance(value, torch.Tensor):
        return value
    detached = value.detach()
    return detached.cpu()
def _append_batch_output(
    stitched_tensors: Dict[str, List[torch.Tensor]],
    stitched_scalars: Dict[str, Any],
    output: Dict[str, Any],
    actual_frames: int,
    slice_start: int,
) -> None:
    """Accumulate one batch's model output into the running stitch buffers.

    Args:
        stitched_tensors: key -> list of per-batch CPU chunks; mutated in place.
        stitched_scalars: key -> latest scalar output; mutated in place.
        output: the model's output mapping for this batch.
        actual_frames: number of frames fed to the model in this batch. Only
            tensors with ``ndim >= 2`` and ``shape[1] == actual_frames`` are
            treated as per-frame and stitched; anything else is ignored.
        slice_start: first frame index (along dim 1) to keep, e.g. 1 to drop a
            frame already contributed by the previous batch.
    """
    for name in _SEQUENCE_OUTPUT_KEYS:
        tensor = output.get(name)
        is_per_frame = (
            isinstance(tensor, torch.Tensor)
            and tensor.ndim >= 2
            and tensor.shape[1] == actual_frames
        )
        if is_per_frame:
            kept = tensor[:, slice_start:].detach().cpu()
            stitched_tensors.setdefault(name, []).append(kept)
    # Scalars are overwritten, so the last batch that reports a key wins.
    stitched_scalars.update(
        {
            name: _move_scalar_to_cpu(output[name])
            for name in _SCALAR_OUTPUT_KEYS
            if name in output
        }
    )
def _finalize_stitched_batches(
    stitched_tensors: Dict[str, List[torch.Tensor]],
    stitched_scalars: Dict[str, Any],
) -> Dict[str, Any]:
    """Build the final output mapping from accumulated batch chunks.

    Each non-empty chunk list is concatenated along dim 1; a single chunk is
    returned as-is, without copying, and empty lists are dropped. Scalar entries
    are merged in afterwards, so a scalar overrides a tensor under the same key.
    """
    merged: Dict[str, Any] = {}
    for name, pieces in stitched_tensors.items():
        if len(pieces) == 1:
            merged[name] = pieces[0]
        elif pieces:
            merged[name] = torch.cat(pieces, dim=1)
    merged.update(stitched_scalars)
    return merged
def _batch_frame_spans(num_frames: int, step_frames: int) -> List[tuple]:
    """Return ``(start, end)`` frame ranges for overlapping refresh batches.

    Each span holds at most ``step_frames + 1`` frames. Consecutive spans share
    exactly one frame (the last frame of one span is the first of the next), and
    together the spans cover ``range(num_frames)``. Every span after the first
    contributes at least one new frame: no span consists solely of the overlap
    frame.

    Raises:
        ValueError: if ``step_frames`` is smaller than 1.
    """
    if step_frames < 1:
        raise ValueError("step_frames must be >= 1")
    if num_frames <= 0:
        return []
    # A span starting on the final frame would only repeat the overlap frame, so
    # starts stop before ``num_frames - 1`` (a 1-frame sequence still gets one span).
    stop = max(num_frames - 1, 1)
    return [
        (start, min(start + step_frames + 1, num_frames))
        for start in range(0, stop, step_frames)
    ]


def run_batch_refresh(
    model,
    images,
    is_keyframe,
    keyframe_indices,
    mode: str,
    keyframe_stride: int,
    refresh: int,
    rel_pose_cfg,
):
    """Run ``model`` over a long sequence in overlapping fixed-length batches.

    Frames are indexed along dim 1 of ``images`` (``images.shape[:2]`` is read
    as batch size and frame count). Each batch spans ``refresh - 1`` keyframe
    strides, i.e. ``(refresh - 1) * keyframe_stride + 1`` frames. Consecutive
    batches overlap by one frame, and that overlap frame is dropped from every
    batch after the first when stitching.

    Args:
        model: callable module; its device is taken from its parameters.
        images: frame tensor sliced as ``images[:, start:end]``.
        is_keyframe: optional per-frame flags, sliced along dim 1 like ``images``.
        keyframe_indices: optional per-frame keyframe indices in absolute frame
            numbers; they are re-expressed relative to each batch.
        mode: forwarded to ``model``.
        keyframe_stride: frames between keyframes; must be >= 1.
        refresh: batch length in keyframes (validated by ``_refresh_intervals``).
        rel_pose_cfg: optional mapping; when given together with
            ``is_keyframe``, relative-pose inputs are built from it
            (``num_iterations`` defaults to 4).

    Returns:
        Dict of per-frame outputs concatenated over batches, plus scalar outputs
        taken from the last batch (see ``_finalize_stitched_batches``).

    Raises:
        ValueError: if ``refresh < 2`` or ``keyframe_stride < 1``.
    """
    num_frames = images.shape[1]
    device = _model_device(model)
    refresh_intervals = _refresh_intervals(refresh)
    if keyframe_stride < 1:
        raise ValueError("keyframe_stride must be >= 1")
    step_frames = refresh_intervals * keyframe_stride
    stitched_tensors: Dict[str, List[torch.Tensor]] = {}
    stitched_scalars: Dict[str, Any] = {}
    spans = _batch_frame_spans(num_frames, step_frames)
    for batch_idx, (start_frame, end_frame) in enumerate(spans):
        batch_images = images[:, start_frame:end_frame].to(device, non_blocking=True)
        batch_is_keyframe = (
            is_keyframe[:, start_frame:end_frame].clone()
            if is_keyframe is not None
            else None
        )
        batch_keyframe_indices = (
            keyframe_indices[:, start_frame:end_frame].clone()
            if keyframe_indices is not None
            else None
        )
        # Every batch after the first begins on the previous batch's last frame
        # (the overlap frame); mark it as a keyframe whose keyframe index points
        # at itself.
        if batch_idx > 0 and batch_is_keyframe is not None:
            batch_is_keyframe[:, 0] = True
            if batch_keyframe_indices is not None:
                batch_keyframe_indices[:, 0] = start_frame
        if batch_keyframe_indices is not None:
            # Make indices batch-relative; indices that fall outside the batch
            # are clamped onto its bounds.
            batch_keyframe_indices = torch.clamp(
                batch_keyframe_indices - start_frame, 0, end_frame - start_frame - 1
            )
        batch_rel_pose_inputs = None
        if rel_pose_cfg is not None and batch_is_keyframe is not None:
            batch_is_keyframe = batch_is_keyframe.to(device, non_blocking=True)
            if batch_keyframe_indices is not None:
                batch_keyframe_indices = batch_keyframe_indices.to(
                    device, non_blocking=True
                )
            batch_rel_pose_inputs = {
                "is_keyframe": batch_is_keyframe,
                "keyframe_indices": batch_keyframe_indices,
                "num_iterations": rel_pose_cfg.get("num_iterations", 4),
            }
        elif batch_is_keyframe is not None:
            batch_is_keyframe = batch_is_keyframe.to(device, non_blocking=True)
        batch_output = model(
            images=batch_images,
            mode=mode,
            rel_pose_inputs=batch_rel_pose_inputs,
            is_keyframe=batch_is_keyframe,
        )
        _append_batch_output(
            stitched_tensors,
            stitched_scalars,
            batch_output,
            actual_frames=end_frame - start_frame,
            # Drop the overlap frame already contributed by the previous batch.
            slice_start=0 if batch_idx == 0 else 1,
        )
        # Release per-batch tensors before the next batch allocates its own.
        del batch_output
        del batch_images
        del batch_is_keyframe
        del batch_keyframe_indices
    return _finalize_stitched_batches(stitched_tensors, stitched_scalars)
def run_streaming_refresh(
    model,
    images,
    is_keyframe,
    keyframe_indices,
    mode: str,
    window_size: int,
    refresh: int,
    rel_pose_cfg,
):
    """Stream ``images`` one frame at a time through a ``StreamSession``.

    Every frame is forwarded with ``record=True``. Each keyframe after frame 0
    increments a counter. When the counter reaches a multiple of
    ``refresh - 1``, the following happens:

    * the session cache is cleared with ``clear_cache_only()``;
    * the current frame becomes the new segment start;
    * the frame is forwarded again with ``record=False`` and an all-zero
      keyframe index. Presumably this re-seeds the cleared cache with that
      frame as its own reference; confirm against ``StreamSession``.

    ``keyframe_indices`` are shifted by the current segment start and clamped
    at 0 before each forward.

    NOTE(review): the keyframe test calls ``.item()`` on
    ``is_keyframe[:, s : s + 1]``, which only works when that slice holds
    exactly one element (effectively batch size 1). ``rel_pose_cfg`` is accepted
    but not used here.

    Returns:
        Whatever ``session.get_all_predictions()`` returns.
    """
    _, num_frames = images.shape[:2]
    device = _model_device(model)
    refresh_intervals = _refresh_intervals(refresh)
    session = StreamSession(model, mode=mode, window_size=window_size)
    keyframe_count = 0
    segment_start = 0
    for frame_idx in range(num_frames):
        window = slice(frame_idx, frame_idx + 1)
        frame = images[:, window].to(device, non_blocking=True)
        if is_keyframe is None:
            frame_is_kf = None
        else:
            frame_is_kf = is_keyframe[:, window].to(device, non_blocking=True)
        if keyframe_indices is None:
            frame_kf_idx = None
        else:
            # Keyframe indices are relative to the current segment's start.
            relative = keyframe_indices[:, window].clone() - segment_start
            frame_kf_idx = torch.clamp(relative, min=0).to(device, non_blocking=True)
        session.forward_stream(
            frame,
            is_keyframe=frame_is_kf,
            keyframe_indices=frame_kf_idx,
            record=True,
        )
        # Same evaluation order as a short-circuit "None / not keyframe / first frame" skip.
        counts_as_keyframe = (
            frame_is_kf is not None
            and bool(frame_is_kf.item())
            and frame_idx > 0
        )
        if counts_as_keyframe:
            keyframe_count += 1
            if keyframe_count % refresh_intervals == 0:
                session.clear_cache_only()
                segment_start = frame_idx
                self_reference = (
                    None if frame_kf_idx is None else torch.zeros_like(frame_kf_idx)
                )
                session.forward_stream(
                    frame,
                    is_keyframe=frame_is_kf,
                    keyframe_indices=self_reference,
                    record=False,
                )
        # Drop per-frame tensors before the next frame is moved to the device.
        del frame, frame_is_kf, frame_kf_idx
    return session.get_all_predictions()