Update the processor to the working version for Video-MME
Browse files- chunk_utils.py +25 -7
chunk_utils.py
CHANGED
|
@@ -74,7 +74,7 @@ def _visual_token_cums(
|
|
| 74 |
|
| 75 |
else:
|
| 76 |
if not in_video:
|
| 77 |
-
cums.append(ChunkCum(cum=
|
| 78 |
else:
|
| 79 |
cum += 1
|
| 80 |
token_idx += 1
|
|
@@ -115,6 +115,24 @@ class Chunk:
|
|
| 115 |
video_grid_thws: list[tuple[int, int, int]]
|
| 116 |
|
| 117 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 118 |
def chunk_tokens(
|
| 119 |
max_chunk_size: int,
|
| 120 |
input_ids: torch.Tensor | np.ndarray,
|
|
@@ -147,22 +165,22 @@ def chunk_tokens(
|
|
| 147 |
current_video_grid_thws: list[tuple[int, int, int]] = []
|
| 148 |
|
| 149 |
for cum in sequence_cums:
|
| 150 |
-
if cum.image_grid_thw is not None:
|
| 151 |
-
current_image_grid_thws.append(cum.image_grid_thw)
|
| 152 |
-
if cum.video_grid_thw is not None:
|
| 153 |
-
current_video_grid_thws.append(cum.video_grid_thw)
|
| 154 |
if current_chunk_size + cum.cum > max_chunk_size:
|
| 155 |
chunks.append(Chunk(
|
| 156 |
start=current_chunk_start,
|
| 157 |
end=current_chunk_start + current_chunk_size,
|
| 158 |
image_grid_thws=current_image_grid_thws,
|
| 159 |
-
video_grid_thws=current_video_grid_thws
|
| 160 |
))
|
| 161 |
current_chunk_start += current_chunk_size
|
| 162 |
current_chunk_size = 0
|
| 163 |
current_image_grid_thws = []
|
| 164 |
current_video_grid_thws = []
|
| 165 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 166 |
current_chunk_size += cum.cum
|
| 167 |
|
| 168 |
if current_chunk_size > 0:
|
|
@@ -170,7 +188,7 @@ def chunk_tokens(
|
|
| 170 |
start=current_chunk_start,
|
| 171 |
end=current_chunk_start + current_chunk_size,
|
| 172 |
image_grid_thws=current_image_grid_thws,
|
| 173 |
-
video_grid_thws=current_video_grid_thws,
|
| 174 |
))
|
| 175 |
|
| 176 |
chunked_cums.append(chunks)
|
|
|
|
| 74 |
|
| 75 |
else:
|
| 76 |
if not in_video:
|
| 77 |
+
cums.append(ChunkCum(cum=1, image_grid_thw=None, video_grid_thw=None))
|
| 78 |
else:
|
| 79 |
cum += 1
|
| 80 |
token_idx += 1
|
|
|
|
| 115 |
video_grid_thws: list[tuple[int, int, int]]
|
| 116 |
|
| 117 |
|
| 118 |
+
def _merge_video_grid_thws(
|
| 119 |
+
thws: list[tuple[int, int, int]],
|
| 120 |
+
) -> list[tuple[int, int, int]]:
|
| 121 |
+
"""Merge consecutive video grid_thws that share the same (h, w)."""
|
| 122 |
+
if not thws:
|
| 123 |
+
return thws
|
| 124 |
+
merged: list[tuple[int, int, int]] = []
|
| 125 |
+
cur_t, cur_h, cur_w = thws[0]
|
| 126 |
+
for t, h, w in thws[1:]:
|
| 127 |
+
if h == cur_h and w == cur_w:
|
| 128 |
+
cur_t += t
|
| 129 |
+
else:
|
| 130 |
+
merged.append((cur_t, cur_h, cur_w))
|
| 131 |
+
cur_t, cur_h, cur_w = t, h, w
|
| 132 |
+
merged.append((cur_t, cur_h, cur_w))
|
| 133 |
+
return merged
|
| 134 |
+
|
| 135 |
+
|
| 136 |
def chunk_tokens(
|
| 137 |
max_chunk_size: int,
|
| 138 |
input_ids: torch.Tensor | np.ndarray,
|
|
|
|
| 165 |
current_video_grid_thws: list[tuple[int, int, int]] = []
|
| 166 |
|
| 167 |
for cum in sequence_cums:
|
|
|
|
|
|
|
|
|
|
|
|
|
| 168 |
if current_chunk_size + cum.cum > max_chunk_size:
|
| 169 |
chunks.append(Chunk(
|
| 170 |
start=current_chunk_start,
|
| 171 |
end=current_chunk_start + current_chunk_size,
|
| 172 |
image_grid_thws=current_image_grid_thws,
|
| 173 |
+
video_grid_thws=_merge_video_grid_thws(current_video_grid_thws),
|
| 174 |
))
|
| 175 |
current_chunk_start += current_chunk_size
|
| 176 |
current_chunk_size = 0
|
| 177 |
current_image_grid_thws = []
|
| 178 |
current_video_grid_thws = []
|
| 179 |
|
| 180 |
+
if cum.image_grid_thw is not None:
|
| 181 |
+
current_image_grid_thws.append(cum.image_grid_thw)
|
| 182 |
+
if cum.video_grid_thw is not None:
|
| 183 |
+
current_video_grid_thws.append(cum.video_grid_thw)
|
| 184 |
current_chunk_size += cum.cum
|
| 185 |
|
| 186 |
if current_chunk_size > 0:
|
|
|
|
| 188 |
start=current_chunk_start,
|
| 189 |
end=current_chunk_start + current_chunk_size,
|
| 190 |
image_grid_thws=current_image_grid_thws,
|
| 191 |
+
video_grid_thws=_merge_video_grid_thws(current_video_grid_thws),
|
| 192 |
))
|
| 193 |
|
| 194 |
chunked_cums.append(chunks)
|