| import gradio as gr |
| import torch |
| import os |
| import tempfile |
| import numpy as np |
| from models import Model |
| from dataset import extract_features |
| from eval import predict |
|
|
| |
_DEFAULT_CHECKPOINT = 'checkpoint/ckp_best.pth.tar'


def load_model(checkpoint_path=_DEFAULT_CHECKPOINT):
    """Rebuild a ``Model`` from a saved checkpoint and switch it to eval mode.

    The checkpoint must be a mapping that stores the constructor keyword
    arguments under ``'config'`` and the weights under ``'state_dict'``.
    All tensors are mapped onto the CPU, whatever device they were saved from.

    Note: ``torch.load`` unpickles the file, so only load trusted checkpoints.
    """
    ckpt = torch.load(checkpoint_path, map_location='cpu')
    net = Model(**ckpt['config'])
    net.load_state_dict(ckpt['state_dict'])
    net.eval()
    return net


# Built once at import time so every request reuses the same weights.
model = load_model()
|
|
def _materialize_video(video_file, workdir):
    """Return a filesystem path for *video_file*, writing it into *workdir* if needed.

    Gradio's ``gr.Video`` input passes its callback a file path string. For
    direct callers, file-like objects (anything with ``.read()``) and objects
    that only expose a ``.name`` path are also accepted.

    Raises:
        TypeError: if *video_file* is none of the supported forms.
    """
    if isinstance(video_file, (str, os.PathLike)):
        return os.fspath(video_file)
    if hasattr(video_file, "read"):
        # File-like upload: copy its bytes so feature extraction gets a real path.
        path = os.path.join(workdir, "input.mp4")
        with open(path, "wb") as f:
            f.write(video_file.read())
        return path
    name = getattr(video_file, "name", None)
    if name:
        return name
    raise TypeError(f"Unsupported video input type: {type(video_file).__name__}")


def process_video(video_file):
    """Run temporal action localization on an uploaded video.

    Args:
        video_file: The value supplied by the ``gr.Video`` input. It may be a
            path string, a file-like object, or ``None`` when nothing was
            uploaded.

    Returns:
        One ``"<label>: <start>s - <end>s"`` line per prediction, with times
        shown to two decimal places, joined by newlines. Returns an empty
        string when ``predict`` yields nothing, and a short notice when no
        video was given.
    """
    if video_file is None:
        return "No video provided."

    # Keep all intermediate files in a private directory. It is removed on
    # exit, even if feature extraction or prediction raises.
    with tempfile.TemporaryDirectory() as workdir:
        video_path = _materialize_video(video_file, workdir)
        features = extract_features(video_path)

        # Use an explicit name instead of str.replace on the full path, which
        # would also rewrite any ".mp4" appearing in directory names.
        npz_path = os.path.join(workdir, "features.npz")
        np.savez(npz_path, features=features)

        # NOTE(review): assumes predict yields (label, start, end) triples with
        # times in seconds -- confirm against eval.predict.
        predictions = predict(model, npz_path)
        # Format inside the ``with`` in case ``predict`` returns a lazy iterable
        # that still reads from the temporary files.
        return "\n".join(
            f"{label}: {start:.2f}s - {end:.2f}s"
            for label, start, end in predictions
        )
|
|
# UI wiring: a single video upload in, a plain-text list of detected actions out.
_video_input = gr.Video(label="Upload a video")
_actions_output = gr.Textbox(label="Detected Actions")

demo = gr.Interface(
    fn=process_video,
    inputs=_video_input,
    outputs=_actions_output,
    title="Temporal Action Localization",
)


if __name__ == "__main__":
    # Start the web server only when run as a script, not on import.
    demo.launch()
|
|