leobcc commited on
Commit ·
a707cb2
1
Parent(s): 4d3bf91
Implementation of evaluate.py
Browse files- code/confs/dataset/video.yaml +1 -1
- code/evaluate.py +26 -19
- visualization/vis.py +2 -1
code/confs/dataset/video.yaml
CHANGED
|
@@ -20,7 +20,7 @@ valid:
|
|
| 20 |
batch_size: 1
|
| 21 |
drop_last: False
|
| 22 |
shuffle: False
|
| 23 |
-
worker:
|
| 24 |
|
| 25 |
num_sample : -1
|
| 26 |
pixel_per_batch: 2048
|
|
|
|
| 20 |
batch_size: 1
|
| 21 |
drop_last: False
|
| 22 |
shuffle: False
|
| 23 |
+
worker: 1
|
| 24 |
|
| 25 |
num_sample : -1
|
| 26 |
pixel_per_batch: 2048
|
code/evaluate.py
CHANGED
|
@@ -2,31 +2,38 @@ import torch
|
|
| 2 |
from v2a_model import V2AModel
|
| 3 |
from skimage.metrics import structural_similarity as ssim
|
| 4 |
from skimage.metrics import peak_signal_noise_ratio as psnr
|
|
|
|
|
|
|
| 5 |
|
| 6 |
-
|
| 7 |
-
|
|
|
|
|
|
|
| 8 |
|
| 9 |
-
model = V2AModel(opt)
|
| 10 |
|
| 11 |
-
model.load_state_dict(checkpoint['model_state_dict'])
|
| 12 |
|
| 13 |
-
model.eval()
|
| 14 |
|
| 15 |
-
# Prepare your input data (replace this with your actual data preparation)
|
| 16 |
-
input_data =
|
| 17 |
|
| 18 |
-
# Pass the input through the model to get the outputs
|
| 19 |
-
with torch.no_grad():
|
| 20 |
-
|
| 21 |
|
| 22 |
-
# Assuming 'rgb_values' and 'fg_rgb_values' are the RGB images
|
| 23 |
-
predicted_rgb_image = output['rgb_values'].cpu().numpy() # Convert to NumPy array
|
| 24 |
-
target_rgb_image = output['fg_rgb_values'].cpu().numpy() # Convert to NumPy array
|
| 25 |
|
| 26 |
-
# Calculate SSIM and PSNR
|
| 27 |
-
ssim_value = ssim(target_rgb_image, predicted_rgb_image, multichannel=True)
|
| 28 |
-
psnr_value = psnr(target_rgb_image, predicted_rgb_image)
|
| 29 |
|
| 30 |
-
# Log or print the SSIM and PSNR values
|
| 31 |
-
print(f'SSIM: {ssim_value:.4f}')
|
| 32 |
-
print(f'PSNR: {psnr_value:.4f}')
|
|
|
|
|
|
|
|
|
|
|
|
| 2 |
from v2a_model import V2AModel
|
| 3 |
from skimage.metrics import structural_similarity as ssim
|
| 4 |
from skimage.metrics import peak_signal_noise_ratio as psnr
|
| 5 |
+
import glob
|
| 6 |
+
from lib.datasets import create_dataset
|
| 7 |
|
| 8 |
+
@hydra.main(config_path="confs", config_name="base")
|
| 9 |
+
def main(opt):
|
| 10 |
+
checkpoint_path = sorted(glob.glob("checkpoints/*.ckpt"))[-1] # Replace with the actual path (if not specified uses the last checkpoint)
|
| 11 |
+
checkpoint = torch.load(checkpoint_path)
|
| 12 |
|
| 13 |
+
model = V2AModel(opt)
|
| 14 |
|
| 15 |
+
model.load_state_dict(checkpoint['model_state_dict'])
|
| 16 |
|
| 17 |
+
model.eval()
|
| 18 |
|
| 19 |
+
# Prepare your input data (replace this with your actual data preparation)
|
| 20 |
+
input_data = create_dataset(opt.dataset.metainfo, opt.dataset.test) # Replace with your data
|
| 21 |
|
| 22 |
+
# Pass the input through the model to get the outputs
|
| 23 |
+
with torch.no_grad():
|
| 24 |
+
output = model(input_data)
|
| 25 |
|
| 26 |
+
# Assuming 'rgb_values' and 'fg_rgb_values' are the RGB images
|
| 27 |
+
predicted_rgb_image = output['rgb_values'].cpu().numpy() # Convert to NumPy array
|
| 28 |
+
target_rgb_image = output['fg_rgb_values'].cpu().numpy() # Convert to NumPy array
|
| 29 |
|
| 30 |
+
# Calculate SSIM and PSNR
|
| 31 |
+
ssim_value = ssim(target_rgb_image, predicted_rgb_image, multichannel=True)
|
| 32 |
+
psnr_value = psnr(target_rgb_image, predicted_rgb_image)
|
| 33 |
|
| 34 |
+
# Log or print the SSIM and PSNR values
|
| 35 |
+
print(f'SSIM: {ssim_value:.4f}')
|
| 36 |
+
print(f'PSNR: {psnr_value:.4f}')
|
| 37 |
+
|
| 38 |
+
if __name__ == '__main__':
|
| 39 |
+
main()
|
visualization/vis.py
CHANGED
|
@@ -45,9 +45,9 @@ def vis_dynamic_canonical_train(args):
|
|
| 45 |
faces = []
|
| 46 |
vertex_normals = []
|
| 47 |
deformed_mesh_paths = sorted(glob.glob(f'{args.path}/*.ply'))
|
| 48 |
-
print(deformed_mesh_paths)
|
| 49 |
for deformed_mesh_path in deformed_mesh_paths:
|
| 50 |
mesh = trimesh.load(deformed_mesh_path, process=False)
|
|
|
|
| 51 |
# center the human
|
| 52 |
mesh.vertices = mesh.vertices - mesh.vertices.mean(axis=0)
|
| 53 |
vertices.append(mesh.vertices)
|
|
@@ -67,6 +67,7 @@ def vis_dynamic_canonical_train(args):
|
|
| 67 |
viewer.scene.origin.enabled = False
|
| 68 |
viewer.scene.floor.enabled = True
|
| 69 |
viewer.run()
|
|
|
|
| 70 |
if __name__ == '__main__':
|
| 71 |
parser = argparse.ArgumentParser(description='3D Visualization')
|
| 72 |
# static canonical mesh or dynamic sequence
|
|
|
|
| 45 |
faces = []
|
| 46 |
vertex_normals = []
|
| 47 |
deformed_mesh_paths = sorted(glob.glob(f'{args.path}/*.ply'))
|
|
|
|
| 48 |
for deformed_mesh_path in deformed_mesh_paths:
|
| 49 |
mesh = trimesh.load(deformed_mesh_path, process=False)
|
| 50 |
+
print(mesh)
|
| 51 |
# center the human
|
| 52 |
mesh.vertices = mesh.vertices - mesh.vertices.mean(axis=0)
|
| 53 |
vertices.append(mesh.vertices)
|
|
|
|
| 67 |
viewer.scene.origin.enabled = False
|
| 68 |
viewer.scene.floor.enabled = True
|
| 69 |
viewer.run()
|
| 70 |
+
|
| 71 |
if __name__ == '__main__':
|
| 72 |
parser = argparse.ArgumentParser(description='3D Visualization')
|
| 73 |
# static canonical mesh or dynamic sequence
|