Commit a707cb2 · leobcc committed
1 Parent(s): 4d3bf91

Implementation of evaluate.py

code/confs/dataset/video.yaml CHANGED
@@ -20,7 +20,7 @@ valid:
   batch_size: 1
   drop_last: False
   shuffle: False
-  worker: 2
+  worker: 1
 
   num_sample : -1
   pixel_per_batch: 2048
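
For context, the worker key in this valid block is the dataloader worker count. A minimal sketch of how such a block is typically consumed when building the validation loader follows; the build_valid_loader name and the cfg/dataset wiring are illustrative assumptions, not part of this commit.

    # Illustrative sketch only: assumes a Hydra/OmegaConf config `cfg` whose
    # `valid` block matches the YAML above and an already-built `dataset`.
    from torch.utils.data import DataLoader

    def build_valid_loader(cfg, dataset):
        return DataLoader(
            dataset,
            batch_size=cfg.valid.batch_size,  # 1
            drop_last=cfg.valid.drop_last,    # False
            shuffle=cfg.valid.shuffle,        # False
            num_workers=cfg.valid.worker,     # now 1 instead of 2
        )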
code/evaluate.py CHANGED
@@ -2,31 +2,39 @@ import torch
 from v2a_model import V2AModel
 from skimage.metrics import structural_similarity as ssim
 from skimage.metrics import peak_signal_noise_ratio as psnr
-
-checkpoint_path = 'path/to/your/checkpoint.pth'  # Replace with the actual path
-checkpoint = torch.load(checkpoint_path)
-
-model = V2AModel(opt)
-
-model.load_state_dict(checkpoint['model_state_dict'])
-
-model.eval()
-
-# Prepare your input data (replace this with your actual data preparation)
-input_data = torch.randn(1, 3, 256, 256)  # Replace with your data
-
-# Pass the input through the model to get the outputs
-with torch.no_grad():
-    output = model(input_data)
-
-# Assuming 'rgb_values' and 'fg_rgb_values' are the RGB images
-predicted_rgb_image = output['rgb_values'].cpu().numpy()  # Convert to NumPy array
-target_rgb_image = output['fg_rgb_values'].cpu().numpy()  # Convert to NumPy array
-
-# Calculate SSIM and PSNR
-ssim_value = ssim(target_rgb_image, predicted_rgb_image, multichannel=True)
-psnr_value = psnr(target_rgb_image, predicted_rgb_image)
-
-# Log or print the SSIM and PSNR values
-print(f'SSIM: {ssim_value:.4f}')
-print(f'PSNR: {psnr_value:.4f}')
+import glob
+import hydra
+from lib.datasets import create_dataset
+
+@hydra.main(config_path="confs", config_name="base")
+def main(opt):
+    # Uses the latest checkpoint unless an explicit path is substituted here
+    checkpoint_path = sorted(glob.glob("checkpoints/*.ckpt"))[-1]
+    checkpoint = torch.load(checkpoint_path)
+
+    model = V2AModel(opt)
+
+    model.load_state_dict(checkpoint['model_state_dict'])
+
+    model.eval()
+
+    # Prepare the input data from the dataset config
+    input_data = create_dataset(opt.dataset.metainfo, opt.dataset.test)
+
+    # Pass the input through the model to get the outputs
+    with torch.no_grad():
+        output = model(input_data)
+
+    # 'rgb_values' and 'fg_rgb_values' are the predicted and target RGB values
+    predicted_rgb_image = output['rgb_values'].cpu().numpy()  # Convert to NumPy array
+    target_rgb_image = output['fg_rgb_values'].cpu().numpy()  # Convert to NumPy array
+
+    # Calculate SSIM and PSNR
+    ssim_value = ssim(target_rgb_image, predicted_rgb_image, multichannel=True)
+    psnr_value = psnr(target_rgb_image, predicted_rgb_image)
+
+    # Log or print the SSIM and PSNR values
+    print(f'SSIM: {ssim_value:.4f}')
+    print(f'PSNR: {psnr_value:.4f}')
+
+if __name__ == '__main__':
+    main()
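
The diff above computes a single SSIM/PSNR pair from one forward pass over the object returned by create_dataset. A common variant is to iterate over the test split and average the metrics per image; the sketch below illustrates that pattern. The 'rgb_values'/'fg_rgb_values' keys and create_dataset come from the diff, while the DataLoader wrapping, the (H, W, 3) reshape, and data_range=1.0 are illustrative assumptions, not part of this commit.

    # Illustrative sketch only: per-image SSIM/PSNR averaged over the test split.
    import torch
    from torch.utils.data import DataLoader
    from skimage.metrics import structural_similarity as ssim
    from skimage.metrics import peak_signal_noise_ratio as psnr

    def evaluate_split(model, dataset, img_h, img_w):
        loader = DataLoader(dataset, batch_size=1, shuffle=False)
        ssim_scores, psnr_scores = [], []
        with torch.no_grad():
            for batch in loader:
                output = model(batch)
                # Assumes the rgb outputs are flat (H*W, 3) tensors with values in [0, 1]
                pred = output['rgb_values'].cpu().numpy().reshape(img_h, img_w, 3)
                target = output['fg_rgb_values'].cpu().numpy().reshape(img_h, img_w, 3)
                # channel_axis is the newer skimage name for multichannel=True
                ssim_scores.append(ssim(target, pred, channel_axis=-1, data_range=1.0))
                psnr_scores.append(psnr(target, pred, data_range=1.0))
        return sum(ssim_scores) / len(ssim_scores), sum(psnr_scores) / len(psnr_scores)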
visualization/vis.py CHANGED
@@ -45,9 +45,9 @@ def vis_dynamic_canonical_train(args):
     faces = []
     vertex_normals = []
     deformed_mesh_paths = sorted(glob.glob(f'{args.path}/*.ply'))
-    print(deformed_mesh_paths)
     for deformed_mesh_path in deformed_mesh_paths:
         mesh = trimesh.load(deformed_mesh_path, process=False)
+        print(mesh)
         # center the human
         mesh.vertices = mesh.vertices - mesh.vertices.mean(axis=0)
         vertices.append(mesh.vertices)
@@ -67,6 +67,7 @@ def vis_dynamic_canonical_train(args):
     viewer.scene.origin.enabled = False
     viewer.scene.floor.enabled = True
     viewer.run()
+
 if __name__ == '__main__':
     parser = argparse.ArgumentParser(description='3D Visualization')
     # static canonical mesh or dynamic sequence