| |
| import torch |
| from torchvision import transforms |
| from PIL import Image |
| import numpy as np |
| from backend.keyframes.model import DSN |
| import torch.nn as nn |
| import cv2 |
| import time |
| import os |
| import srt |
| from backend.keyframes.extract_frames import extract_frames |
| from backend.utils import copy_and_rename_file, get_black_bar_coordinates, crop_image |
| import signal |
| import threading |
|
|
| |
| |
# Lazily-initialized module-level caches: the GoogLeNet backbone and its
# preprocessing pipeline are built once per process by _get_features and
# reused on every subsequent call.
_googlenet_model = None
_preprocess_pipeline = None
|
|
def _get_features(frames, gpu=True, batch_size=1):
    """Extract a 1024-d GoogLeNet feature vector for each frame image.

    Args:
        frames: list of file paths to frame images on disk.
        gpu: run inference on CUDA when True, CPU otherwise.
        batch_size: unused; kept only for backward compatibility with
            existing callers.

    Returns:
        np.ndarray of dtype float32 — one pooled feature per frame
        (shape (len(frames), 1024) for multiple frames).
    """
    global _googlenet_model, _preprocess_pipeline

    if _googlenet_model is None:
        print("🔄 Loading GoogLeNet model (this happens only once)...")
        backbone = torch.hub.load('pytorch/vision:v0.10.0', 'googlenet',
                                  weights='GoogLeNet_Weights.DEFAULT')
        # Drop the final FC classifier so the forward pass yields pooled features.
        _googlenet_model = torch.nn.Sequential(*(list(backbone.children())[:-1]))
        _googlenet_model.eval()

        # Standard ImageNet preprocessing (matches the pretrained weights).
        _preprocess_pipeline = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225]),
        ])
        print("✅ GoogLeNet model loaded successfully")

    # BUG FIX: the original moved the model to CUDA only at first load, so a
    # gpu=True call after a gpu=False first call crashed with a device
    # mismatch.  Moving per call is a no-op when already on the right device.
    device = 'cuda' if gpu else 'cpu'
    _googlenet_model.to(device)

    features = []
    for frame_path in frames:
        # Context manager closes the file handle (original leaked it), and
        # convert('RGB') makes grayscale/RGBA frames compatible with the
        # 3-channel Normalize above.
        with Image.open(frame_path) as input_image:
            input_tensor = _preprocess_pipeline(input_image.convert('RGB'))
        input_batch = input_tensor.unsqueeze(0).to(device)

        with torch.no_grad():
            output = _googlenet_model(input_batch)

        features.append(output.squeeze().cpu().numpy())

    return np.array(features).astype(np.float32)
|
|
| |
| _dsn_models = {} |
|
|
def _get_probs(features, gpu=True, mode=0):
    """Score a frame-feature sequence with the cached DSN model.

    Args:
        features: float32 array of per-frame GoogLeNet features.
        gpu: load/run the model on CUDA when True.
        mode: which pretrained checkpoint to use (1 -> model_1, else model_0).

    Returns:
        numpy array of per-frame highlight probabilities.
    """
    global _dsn_models

    cache_key = f"dsn_model_{mode}_{gpu}"

    if cache_key not in _dsn_models:
        print(f"🔄 Loading DSN model {mode} (this happens only once)...")

        model_path = ("backend/keyframes/pretrained_model/model_1.pth.tar"
                      if mode == 1
                      else "backend/keyframes/pretrained_model/model_0.pth.tar")

        dsn = DSN(in_dim=1024, hid_dim=256, num_layers=1, cell="lstm")

        # On CPU the tensors must be remapped away from their CUDA origin.
        state = (torch.load(model_path) if gpu
                 else torch.load(model_path, map_location='cpu'))
        dsn.load_state_dict(state)

        if gpu:
            dsn = nn.DataParallel(dsn).cuda()

        dsn.eval()
        _dsn_models[cache_key] = dsn
        print(f"✅ DSN model {mode} loaded successfully")

    model = _dsn_models[cache_key]
    seq = torch.from_numpy(features).unsqueeze(0)  # add batch dimension
    if gpu:
        seq = seq.cuda()
    probs = model(seq)
    return probs.data.cpu().squeeze().numpy()
|
|
|
|
| |
def generate_keyframes(video):
    """Extract up to 16 story-relevant keyframes from *video* into frames/final.

    Subtitles are read from the hard-coded file ``test1.srt``.  For each
    subtitle segment, candidate frames are sampled, scored with the GoogLeNet
    + DSN pipeline, and the best story-relevant frames are copied to
    frames/final as frame001.png..frame016.png.  Falls back to uniform
    sampling when nothing was selected, and to a cheap per-subtitle grab
    when the 10-minute watchdog fires.

    Args:
        video: path to the source video file.
    """
    with open("test1.srt") as f:
        data = f.read()

    # BUG FIX: srt.parse() returns a generator.  The original code exhausted
    # it via len(list(subs)) and then re-listed the now-empty generator, so
    # the main loop never saw a single subtitle.  Materialize exactly once.
    subs = list(srt.parse(data))
    total_subs = len(subs)

    torch.cuda.empty_cache()

    def timeout_handler(signum, frame):
        raise TimeoutError("Keyframe generation timed out")

    # SIGALRM handlers may only be installed from the main thread; remember
    # whether we armed the alarm so the finally block doesn't touch one we
    # never set.
    use_alarm = threading.current_thread() is threading.main_thread()
    if use_alarm:
        signal.signal(signal.SIGALRM, timeout_handler)
        signal.alarm(600)  # 10-minute hard budget

    final_dir = os.path.join("frames", "final")
    if not os.path.exists(final_dir):
        os.makedirs(final_dir)
        print(f"Created directory: {final_dir}")

    frame_counter = 1
    print(f"🎯 Processing {total_subs} subtitle segments...")

    try:
        for i, sub in enumerate(subs, 1):
            print(f"📝 Processing segment {i}/{total_subs}: {sub.content[:30]}...")
            sub_dir = os.path.join("frames", f"sub{sub.index}")
            if not os.path.exists(sub_dir):
                os.makedirs(sub_dir)

            frames = extract_frames(video, sub_dir,
                                    sub.start.total_seconds(), sub.end.total_seconds(), 10)

            if len(frames) > 0:
                features = _get_features(frames, gpu=False)
                highlight_scores = _get_probs(features, gpu=False)
                story_frames = _select_story_relevant_frames(frames, highlight_scores, sub)

                for frame_idx in story_frames:
                    if frame_counter <= 16:
                        try:
                            copy_and_rename_file(frames[frame_idx], final_dir,
                                                 f"frame{frame_counter:03}.png")
                            print(f"📖 Story frame {frame_counter}: {sub.content} (score: {highlight_scores[frame_idx]:.3f})")
                            frame_counter += 1
                        except Exception as e:
                            # Best-effort copy: log the failure instead of the
                            # original's silent bare-except swallow.
                            print(f"⚠️ Failed to copy frame: {e}")
            else:
                print(f"⚠️ No frames extracted for subtitle {sub.index}")

        # Nothing selected at all — fall back to uniform sampling of the video.
        if frame_counter == 1:
            print("🚨 No story-relevant frames generated – falling back to uniform extraction…")
            try:
                video_cap = cv2.VideoCapture(video)
                try:
                    total_frames = int(video_cap.get(cv2.CAP_PROP_FRAME_COUNT))
                    step = max(total_frames // 16, 1)
                    extracted = 0
                    frame_idx = 0
                    while extracted < 16 and video_cap.isOpened():
                        video_cap.set(cv2.CAP_PROP_POS_FRAMES, frame_idx)
                        ret, frame = video_cap.read()
                        if not ret:
                            break
                        out_path = os.path.join(final_dir, f"frame{frame_counter:03}.png")
                        cv2.imwrite(out_path, frame)
                        frame_counter += 1
                        extracted += 1
                        frame_idx += step
                finally:
                    # Release the capture even if a read/write raised mid-loop.
                    video_cap.release()
                print(f"✅ Fallback extracted {extracted} uniform frames")
            except Exception as e:
                print(f"Fallback extraction failed: {e}")

        print(f"✅ Generated {frame_counter-1} story-relevant frames")

    except TimeoutError:
        print("⏰ Keyframe generation timed out, using fallback method...")
        # Cheap fallback: one frame from each of the first few subtitles.
        for i, sub in enumerate(subs[:4], 1):
            if frame_counter <= 16:
                try:
                    frames = extract_frames(video, os.path.join("frames", f"sub{sub.index}"),
                                            sub.start.total_seconds(), sub.end.total_seconds(), 1)
                    if frames:
                        copy_and_rename_file(frames[0], final_dir, f"frame{frame_counter:03}.png")
                        print(f"📖 Fallback frame {frame_counter}: {sub.content}")
                        frame_counter += 1
                except Exception as e:
                    print(f"⚠️ Fallback frame failed: {e}")

        print(f"✅ Generated {frame_counter-1} fallback frames")

    finally:
        if use_alarm:
            signal.alarm(0)  # disarm the watchdog
|
|
def _select_story_relevant_frames(frames, highlight_scores, subtitle):
    """Pick up to 3 frame indices by blending AI score with story heuristics.

    Args:
        frames: list of frame image paths.
        highlight_scores: per-frame DSN highlight probabilities.
        subtitle: srt subtitle object whose .content is analyzed for keywords.

    Returns:
        List of indices into *frames*, best first (at most 3).
    """
    try:
        highlight_scores = list(highlight_scores)

        # Per-frame story relevance (faces, motion, complexity, keywords).
        story_scores = [
            _analyze_story_relevance(path, highlight_scores[i], subtitle)
            for i, path in enumerate(frames)
        ]

        # Weighted blend: the AI highlight score dominates (60/40).
        combined_scores = [
            (highlight_scores[i] * 0.6) + (story_scores[i] * 0.4)
            for i in range(len(frames))
        ]

        # (The original also computed a ranking of the raw highlight scores
        # here, but never used it — dead code removed.)
        sorted_combined = [i[0] for i in sorted(enumerate(combined_scores),
                                                key=lambda x: x[1], reverse=True)]

        num_frames_to_select = min(3, len(frames))
        return sorted_combined[:num_frames_to_select]

    except Exception as e:
        print(f"Story selection failed: {e}")
        # Fallback: single best frame by raw highlight score.
        try:
            highlight_scores = list(highlight_scores)
            sorted_indices = [i[0] for i in sorted(enumerate(highlight_scores),
                                                   key=lambda x: x[1], reverse=True)]
            return [sorted_indices[0]] if sorted_indices else [0]
        except Exception:
            # Narrowed from the original bare except (which would have
            # swallowed KeyboardInterrupt/SystemExit too).
            return [0]
|
|
def _analyze_story_relevance(frame_path, ai_score, subtitle):
    """Boost the AI highlight score with visual/content heuristics.

    Returns the boosted score capped at 1.0, or *ai_score* unchanged when the
    image cannot be read or any heuristic fails.
    """
    try:
        img = cv2.imread(frame_path)
        if img is None:
            return ai_score

        # Faces are a strong story signal: +0.2 per detected face.
        gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
        cascade = cv2.CascadeClassifier(
            cv2.data.haarcascades + 'haarcascade_frontalface_default.xml')
        detected = cascade.detectMultiScale(gray, 1.1, 4)
        face_score = len(detected) * 0.2

        # Weighted heuristic boosts, summed in the same order as before.
        story_score = ai_score + face_score
        story_score = story_score + _detect_motion(img) * 0.15
        story_score = story_score + _analyze_scene_complexity(img) * 0.1
        story_score = story_score + _analyze_subtitle_relevance(subtitle.content) * 0.15

        return min(story_score, 1.0)

    except Exception:
        return ai_score
|
|
def _detect_motion(img):
    """Proxy for motion/action in a frame: Canny edge density scaled into [0, 1].

    Returns 0.0 if edge detection fails for any reason.
    """
    try:
        gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
        edges = cv2.Canny(gray, 50, 150)
        edge_density = np.sum(edges > 0) / (edges.shape[0] * edges.shape[1])
        return min(edge_density * 10, 1.0)
    except Exception:
        # Narrowed from bare except, which also swallowed
        # KeyboardInterrupt/SystemExit.
        return 0.0
|
|
def _analyze_scene_complexity(img):
    """Estimate scene complexity as lightness variance (LAB L channel), in [0, 1].

    Returns 0.0 if the color conversion fails for any reason.
    """
    try:
        lab = cv2.cvtColor(img, cv2.COLOR_BGR2LAB)
        l_channel = lab[:, :, 0]
        complexity = np.std(l_channel) / 255.0
        return min(complexity * 2, 1.0)
    except Exception:
        # Narrowed from bare except, which also swallowed
        # KeyboardInterrupt/SystemExit.
        return 0.0
|
|
| def _analyze_subtitle_relevance(subtitle_text): |
| """Analyze subtitle content for story relevance""" |
| |
| important_keywords = [ |
| 'hello', 'goodbye', 'thank', 'please', 'sorry', 'yes', 'no', |
| 'love', 'hate', 'help', 'danger', 'important', 'secret', |
| 'action', 'fight', 'run', 'stop', 'go', 'come', 'leave' |
| ] |
| |
| text_lower = subtitle_text.lower() |
| relevance_score = 0.0 |
| |
| for keyword in important_keywords: |
| if keyword in text_lower: |
| relevance_score += 0.1 |
| |
| return min(relevance_score, 1.0) |
| |
|
|
def black_bar_crop():
    """Crop black bars from every image in frames/final, in place.

    The crop rectangle is computed once from frame001.png and applied to all
    frames.  Returns the (x, y, w, h) rectangle, or (0, 0, 0, 0) when the
    reference image is missing.
    """
    ref_img_path = "frames/final/frame001.png"

    # Guard: without the reference frame there is nothing to measure.
    if not os.path.exists(ref_img_path):
        print(f"❌ Reference image not found: {ref_img_path}")
        return 0, 0, 0, 0

    x, y, w, h = get_black_bar_coordinates(ref_img_path)

    folder_dir = "frames/final"
    if not os.path.exists(folder_dir):
        print(f"❌ Frames directory not found: {folder_dir}")
        return x, y, w, h

    for image in os.listdir(folder_dir):
        img_path = os.path.join("frames", 'final', image)
        if not os.path.exists(img_path):
            continue
        image_data = cv2.imread(img_path)
        if image_data is None:
            # Skip anything cv2 cannot decode (non-image files, etc.).
            continue
        cv2.imwrite(img_path, image_data[y:y + h, x:x + w])

    return x, y, w, h