prithivMLmods committed on
Commit
a0d7bbe
·
verified ·
1 Parent(s): 5ea39f8

update [kernels:flash-attn2] (cleaned) ✅

Browse files
Files changed (1) hide show
  1. app.py +169 -32
app.py CHANGED
@@ -25,8 +25,6 @@ from transformers import (
25
  from gradio.themes import Soft
26
  from gradio.themes.utils import colors, fonts, sizes
27
 
28
- # --- Theme and CSS Definition ---
29
-
30
  colors.steel_blue = colors.Color(
31
  name="steel_blue",
32
  c50="#EBF3F8",
@@ -34,7 +32,7 @@ colors.steel_blue = colors.Color(
34
  c200="#A8CCE1",
35
  c300="#7DB3D2",
36
  c400="#529AC3",
37
- c500="#4682B4", # SteelBlue base color
38
  c600="#3E72A0",
39
  c700="#36638C",
40
  c800="#2E5378",
@@ -89,64 +87,174 @@ class SteelBlueTheme(Soft):
89
 
90
  steel_blue_theme = SteelBlueTheme()
91
 
92
- # Constants for text generation
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
93
  MAX_MAX_NEW_TOKENS = 4096
94
  DEFAULT_MAX_NEW_TOKENS = 1024
95
  MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))
96
 
97
  device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
98
 
99
- # Load DeepCaption-VLA-7B
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
100
  MODEL_ID_N = "prithivMLmods/DeepCaption-VLA-7B"
101
  processor_n = AutoProcessor.from_pretrained(MODEL_ID_N, trust_remote_code=True)
102
  model_n = Qwen2_5_VLForConditionalGeneration.from_pretrained(
103
  MODEL_ID_N,
104
- attn_implementation="flash_attention_2",
105
  trust_remote_code=True,
106
  torch_dtype=torch.float16
107
  ).to(device).eval()
108
 
109
- # Load SkyCaptioner-V1
110
  MODEL_ID_M = "Skywork/SkyCaptioner-V1"
111
  processor_m = AutoProcessor.from_pretrained(MODEL_ID_M, trust_remote_code=True)
112
  model_m = Qwen2_5_VLForConditionalGeneration.from_pretrained(
113
  MODEL_ID_M,
114
- attn_implementation="flash_attention_2",
115
  trust_remote_code=True,
116
  torch_dtype=torch.float16
117
  ).to(device).eval()
118
 
119
- # Load Space Thinker
120
  MODEL_ID_Z = "remyxai/SpaceThinker-Qwen2.5VL-3B"
121
  processor_z = AutoProcessor.from_pretrained(MODEL_ID_Z, trust_remote_code=True)
122
  model_z = Qwen2_5_VLForConditionalGeneration.from_pretrained(
123
  MODEL_ID_Z,
124
- attn_implementation="flash_attention_2",
125
  trust_remote_code=True,
126
  torch_dtype=torch.float16
127
  ).to(device).eval()
128
 
129
- # Load coreOCR-7B-050325-preview
130
  MODEL_ID_K = "prithivMLmods/coreOCR-7B-050325-preview"
131
  processor_k = AutoProcessor.from_pretrained(MODEL_ID_K, trust_remote_code=True)
132
  model_k = Qwen2VLForConditionalGeneration.from_pretrained(
133
  MODEL_ID_K,
134
- attn_implementation="flash_attention_2",
135
  trust_remote_code=True,
136
  torch_dtype=torch.float16
137
  ).to(device).eval()
138
 
139
- # Load remyxai/SpaceOm
140
  MODEL_ID_Y = "remyxai/SpaceOm"
141
  processor_y = AutoProcessor.from_pretrained(MODEL_ID_Y, trust_remote_code=True)
142
  model_y = Qwen2_5_VLForConditionalGeneration.from_pretrained(
143
  MODEL_ID_Y,
144
- attn_implementation="flash_attention_2",
145
  trust_remote_code=True,
146
  torch_dtype=torch.float16
147
  ).to(device).eval()
148
 
149
- # Video sampling
150
  def downsample_video(video_path):
151
  """
152
  Downsamples the video to evenly spaced frames.
@@ -168,13 +276,32 @@ def downsample_video(video_path):
168
  vidcap.release()
169
  return frames
170
 
171
- @spaces.GPU
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
172
  def generate_image(model_name: str, text: str, image: Image.Image,
173
  max_new_tokens: int = 1024,
174
  temperature: float = 0.6,
175
  top_p: float = 0.9,
176
  top_k: int = 50,
177
- repetition_penalty: float = 1.2):
 
178
  """
179
  Generates responses using the selected model for image input.
180
  Yields raw text and Markdown-formatted text.
@@ -224,13 +351,14 @@ def generate_image(model_name: str, text: str, image: Image.Image,
224
  time.sleep(0.01)
225
  yield buffer, buffer
226
 
227
- @spaces.GPU
228
  def generate_video(model_name: str, text: str, video_path: str,
229
  max_new_tokens: int = 1024,
230
  temperature: float = 0.6,
231
  top_p: float = 0.9,
232
  top_k: int = 50,
233
- repetition_penalty: float = 1.2):
 
234
  """
235
  Generates responses using the selected model for video input.
236
  Yields raw text and Markdown-formatted text.
@@ -291,7 +419,6 @@ def generate_video(model_name: str, text: str, video_path: str,
291
  time.sleep(0.01)
292
  yield buffer, buffer
293
 
294
- # Define examples for image and video inference
295
  image_examples = [
296
  ["type out the messy hand-writing as accurately as you can.", "images/1.jpg"],
297
  ["count the number of birds and explain the scene in detail.", "images/2.jpeg"],
@@ -305,16 +432,6 @@ video_examples = [
305
  ["explain the advertisement in detail.", "videos/2.mp4"]
306
  ]
307
 
308
- css = """
309
- #main-title h1 {
310
- font-size: 2.3em !important;
311
- }
312
- #output-title h2 {
313
- font-size: 2.1em !important;
314
- }
315
- """
316
-
317
- # Create the Gradio Interface
318
  with gr.Blocks() as demo:
319
  gr.Markdown("# **VisionScope R2**", elem_id="main-title")
320
  with gr.Row():
@@ -346,14 +463,34 @@ with gr.Blocks() as demo:
346
  label="Select Model",
347
  value="DeepCaption-VLA-7B"
348
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
349
  image_submit.click(
350
  fn=generate_image,
351
- inputs=[model_choice, image_query, image_upload, max_new_tokens, temperature, top_p, top_k, repetition_penalty],
352
  outputs=[output, markdown_output]
353
  )
354
  video_submit.click(
355
  fn=generate_video,
356
- inputs=[model_choice, video_query, video_upload, max_new_tokens, temperature, top_p, top_k, repetition_penalty],
357
  outputs=[output, markdown_output]
358
  )
359
 
 
25
  from gradio.themes import Soft
26
  from gradio.themes.utils import colors, fonts, sizes
27
 
 
 
28
  colors.steel_blue = colors.Color(
29
  name="steel_blue",
30
  c50="#EBF3F8",
 
32
  c200="#A8CCE1",
33
  c300="#7DB3D2",
34
  c400="#529AC3",
35
+ c500="#4682B4",
36
  c600="#3E72A0",
37
  c700="#36638C",
38
  c800="#2E5378",
 
87
 
88
  steel_blue_theme = SteelBlueTheme()
89
 
90
+ css = """
91
+ #main-title h1 {
92
+ font-size: 2.3em !important;
93
+ }
94
+ #output-title h2 {
95
+ font-size: 2.2em !important;
96
+ }
97
+
98
+ /* RadioAnimated Styles */
99
+ .ra-wrap{ width: fit-content; }
100
+ .ra-inner{
101
+ position: relative; display: inline-flex; align-items: center; gap: 0; padding: 6px;
102
+ background: var(--neutral-200); border-radius: 9999px; overflow: hidden;
103
+ }
104
+ .ra-input{ display: none; }
105
+ .ra-label{
106
+ position: relative; z-index: 2; padding: 8px 16px;
107
+ font-family: inherit; font-size: 14px; font-weight: 600;
108
+ color: var(--neutral-500); cursor: pointer; transition: color 0.2s; white-space: nowrap;
109
+ }
110
+ .ra-highlight{
111
+ position: absolute; z-index: 1; top: 6px; left: 6px;
112
+ height: calc(100% - 12px); border-radius: 9999px;
113
+ background: white; box-shadow: 0 2px 4px rgba(0,0,0,0.1);
114
+ transition: transform 0.2s, width 0.2s;
115
+ }
116
+ .ra-input:checked + .ra-label{ color: black; }
117
+
118
+ /* Dark mode adjustments for Radio */
119
+ .dark .ra-inner { background: var(--neutral-800); }
120
+ .dark .ra-label { color: var(--neutral-400); }
121
+ .dark .ra-highlight { background: var(--neutral-600); }
122
+ .dark .ra-input:checked + .ra-label { color: white; }
123
+
124
+ #gpu-duration-container {
125
+ padding: 10px;
126
+ border-radius: 8px;
127
+ background: var(--background-fill-secondary);
128
+ border: 1px solid var(--border-color-primary);
129
+ margin-top: 10px;
130
+ }
131
+ """
132
+
133
  MAX_MAX_NEW_TOKENS = 4096
134
  DEFAULT_MAX_NEW_TOKENS = 1024
135
  MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))
136
 
137
  device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
138
 
139
+ class RadioAnimated(gr.HTML):
140
+ def __init__(self, choices, value=None, **kwargs):
141
+ if not choices or len(choices) < 2:
142
+ raise ValueError("RadioAnimated requires at least 2 choices.")
143
+ if value is None:
144
+ value = choices[0]
145
+
146
+ uid = uuid.uuid4().hex[:8]
147
+ group_name = f"ra-{uid}"
148
+
149
+ inputs_html = "\n".join(
150
+ f"""
151
+ <input class="ra-input" type="radio" name="{group_name}" id="{group_name}-{i}" value="{c}">
152
+ <label class="ra-label" for="{group_name}-{i}">{c}</label>
153
+ """
154
+ for i, c in enumerate(choices)
155
+ )
156
+
157
+ html_template = f"""
158
+ <div class="ra-wrap" data-ra="{uid}">
159
+ <div class="ra-inner">
160
+ <div class="ra-highlight"></div>
161
+ {inputs_html}
162
+ </div>
163
+ </div>
164
+ """
165
+
166
+ js_on_load = r"""
167
+ (() => {
168
+ const wrap = element.querySelector('.ra-wrap');
169
+ const inner = element.querySelector('.ra-inner');
170
+ const highlight = element.querySelector('.ra-highlight');
171
+ const inputs = Array.from(element.querySelectorAll('.ra-input'));
172
+
173
+ if (!inputs.length) return;
174
+
175
+ const choices = inputs.map(i => i.value);
176
+
177
+ function setHighlightByIndex(idx) {
178
+ const n = choices.length;
179
+ const pct = 100 / n;
180
+ highlight.style.width = `calc(${pct}% - 6px)`;
181
+ highlight.style.transform = `translateX(${idx * 100}%)`;
182
+ }
183
+
184
+ function setCheckedByValue(val, shouldTrigger=false) {
185
+ const idx = Math.max(0, choices.indexOf(val));
186
+ inputs.forEach((inp, i) => { inp.checked = (i === idx); });
187
+ setHighlightByIndex(idx);
188
+
189
+ props.value = choices[idx];
190
+ if (shouldTrigger) trigger('change', props.value);
191
+ }
192
+
193
+ setCheckedByValue(props.value ?? choices[0], false);
194
+
195
+ inputs.forEach((inp) => {
196
+ inp.addEventListener('change', () => {
197
+ setCheckedByValue(inp.value, true);
198
+ });
199
+ });
200
+ })();
201
+ """
202
+
203
+ super().__init__(
204
+ value=value,
205
+ html_template=html_template,
206
+ js_on_load=js_on_load,
207
+ **kwargs
208
+ )
209
+
210
+ def apply_gpu_duration(val: str):
211
+ return int(val)
212
+
213
  MODEL_ID_N = "prithivMLmods/DeepCaption-VLA-7B"
214
  processor_n = AutoProcessor.from_pretrained(MODEL_ID_N, trust_remote_code=True)
215
  model_n = Qwen2_5_VLForConditionalGeneration.from_pretrained(
216
  MODEL_ID_N,
217
+ attn_implementation="kernels-community/flash-attn2",
218
  trust_remote_code=True,
219
  torch_dtype=torch.float16
220
  ).to(device).eval()
221
 
 
222
  MODEL_ID_M = "Skywork/SkyCaptioner-V1"
223
  processor_m = AutoProcessor.from_pretrained(MODEL_ID_M, trust_remote_code=True)
224
  model_m = Qwen2_5_VLForConditionalGeneration.from_pretrained(
225
  MODEL_ID_M,
226
+ attn_implementation="kernels-community/flash-attn2",
227
  trust_remote_code=True,
228
  torch_dtype=torch.float16
229
  ).to(device).eval()
230
 
 
231
  MODEL_ID_Z = "remyxai/SpaceThinker-Qwen2.5VL-3B"
232
  processor_z = AutoProcessor.from_pretrained(MODEL_ID_Z, trust_remote_code=True)
233
  model_z = Qwen2_5_VLForConditionalGeneration.from_pretrained(
234
  MODEL_ID_Z,
235
+ attn_implementation="kernels-community/flash-attn2",
236
  trust_remote_code=True,
237
  torch_dtype=torch.float16
238
  ).to(device).eval()
239
 
 
240
  MODEL_ID_K = "prithivMLmods/coreOCR-7B-050325-preview"
241
  processor_k = AutoProcessor.from_pretrained(MODEL_ID_K, trust_remote_code=True)
242
  model_k = Qwen2VLForConditionalGeneration.from_pretrained(
243
  MODEL_ID_K,
244
+ attn_implementation="kernels-community/flash-attn2",
245
  trust_remote_code=True,
246
  torch_dtype=torch.float16
247
  ).to(device).eval()
248
 
 
249
  MODEL_ID_Y = "remyxai/SpaceOm"
250
  processor_y = AutoProcessor.from_pretrained(MODEL_ID_Y, trust_remote_code=True)
251
  model_y = Qwen2_5_VLForConditionalGeneration.from_pretrained(
252
  MODEL_ID_Y,
253
+ attn_implementation="kernels-community/flash-attn2",
254
  trust_remote_code=True,
255
  torch_dtype=torch.float16
256
  ).to(device).eval()
257
 
 
258
  def downsample_video(video_path):
259
  """
260
  Downsamples the video to evenly spaced frames.
 
276
  vidcap.release()
277
  return frames
278
 
279
+ def calc_timeout_image(model_name: str, text: str, image: Image.Image,
280
+ max_new_tokens: int, temperature: float, top_p: float,
281
+ top_k: int, repetition_penalty: float, gpu_timeout: int):
282
+ """Calculate GPU timeout duration for image inference."""
283
+ try:
284
+ return int(gpu_timeout)
285
+ except:
286
+ return 60
287
+
288
+ def calc_timeout_video(model_name: str, text: str, video_path: str,
289
+ max_new_tokens: int, temperature: float, top_p: float,
290
+ top_k: int, repetition_penalty: float, gpu_timeout: int):
291
+ """Calculate GPU timeout duration for video inference."""
292
+ try:
293
+ return int(gpu_timeout)
294
+ except:
295
+ return 60
296
+
297
+ @spaces.GPU(duration=calc_timeout_image)
298
  def generate_image(model_name: str, text: str, image: Image.Image,
299
  max_new_tokens: int = 1024,
300
  temperature: float = 0.6,
301
  top_p: float = 0.9,
302
  top_k: int = 50,
303
+ repetition_penalty: float = 1.2,
304
+ gpu_timeout: int = 60):
305
  """
306
  Generates responses using the selected model for image input.
307
  Yields raw text and Markdown-formatted text.
 
351
  time.sleep(0.01)
352
  yield buffer, buffer
353
 
354
+ @spaces.GPU(duration=calc_timeout_video)
355
  def generate_video(model_name: str, text: str, video_path: str,
356
  max_new_tokens: int = 1024,
357
  temperature: float = 0.6,
358
  top_p: float = 0.9,
359
  top_k: int = 50,
360
+ repetition_penalty: float = 1.2,
361
+ gpu_timeout: int = 90):
362
  """
363
  Generates responses using the selected model for video input.
364
  Yields raw text and Markdown-formatted text.
 
419
  time.sleep(0.01)
420
  yield buffer, buffer
421
 
 
422
  image_examples = [
423
  ["type out the messy hand-writing as accurately as you can.", "images/1.jpg"],
424
  ["count the number of birds and explain the scene in detail.", "images/2.jpeg"],
 
432
  ["explain the advertisement in detail.", "videos/2.mp4"]
433
  ]
434
 
 
 
 
 
 
 
 
 
 
 
435
  with gr.Blocks() as demo:
436
  gr.Markdown("# **VisionScope R2**", elem_id="main-title")
437
  with gr.Row():
 
463
  label="Select Model",
464
  value="DeepCaption-VLA-7B"
465
  )
466
+
467
+ with gr.Row(elem_id="gpu-duration-container"):
468
+ with gr.Column():
469
+ gr.Markdown("**GPU Duration (seconds)**")
470
+ radioanimated_gpu_duration = RadioAnimated(
471
+ choices=["60", "90", "120", "180", "240", "300"],
472
+ value="60",
473
+ elem_id="radioanimated_gpu_duration"
474
+ )
475
+ gpu_duration_state = gr.Number(value=60, visible=False)
476
+
477
+ gr.Markdown("*Note: Higher GPU duration allows for longer processing but consumes more GPU quota. Video tasks typically require higher values (90-180s).*")
478
+
479
+ radioanimated_gpu_duration.change(
480
+ fn=apply_gpu_duration,
481
+ inputs=radioanimated_gpu_duration,
482
+ outputs=[gpu_duration_state],
483
+ api_visibility="private"
484
+ )
485
+
486
  image_submit.click(
487
  fn=generate_image,
488
+ inputs=[model_choice, image_query, image_upload, max_new_tokens, temperature, top_p, top_k, repetition_penalty, gpu_duration_state],
489
  outputs=[output, markdown_output]
490
  )
491
  video_submit.click(
492
  fn=generate_video,
493
+ inputs=[model_choice, video_query, video_upload, max_new_tokens, temperature, top_p, top_k, repetition_penalty, gpu_duration_state],
494
  outputs=[output, markdown_output]
495
  )
496