Skip to content

Commit

Permalink
added squeeze logic to get_frames
Browse files Browse the repository at this point in the history
  • Loading branch information
pauladkisson authored Jan 30, 2024
1 parent ea67c66 commit a7eb0c6
Showing 1 changed file with 5 additions and 1 deletion.
6 changes: 5 additions & 1 deletion src/roiextractors/imagingextractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,11 +144,15 @@ def get_frames(self, frame_idxs: ArrayType, channel: Optional[int] = 0) -> np.nd
"""
if isinstance(frame_idxs, int):
frame_idxs = [frame_idxs]
squeeze = True
assert max(frame_idxs) <= self.get_num_frames(), "'frame_idxs' exceed number of frames"
if np.all(np.diff(frame_idxs) == 0):
return self.get_video(start_frame=frame_idxs[0], end_frame=frame_idxs[-1])
relative_indices = np.array(frame_idxs) - frame_idxs[0]
return self.get_video(start_frame=frame_idxs[0], end_frame=frame_idxs[-1] + 1)[relative_indices, ..., channel]
frames = self.get_video(start_frame=frame_idxs[0], end_frame=frame_idxs[-1] + 1)[relative_indices, ..., channel]
if squeeze:
frames = frames.squeeze()
return frames

def frame_to_time(self, frames: Union[FloatType, np.ndarray]) -> Union[FloatType, np.ndarray]:
"""Convert user-inputted frame indices to times with units of seconds.
Expand Down

0 comments on commit a7eb0c6

Please sign in to comment.