| import torch |
| from argparse import ArgumentParser |
| import os, json |
| from tqdm import tqdm |
|
|
| current_dir = os.path.abspath(os.path.dirname(__file__)) |
|
|
| from datasets import NWPUTest, Resize2Multiple |
| from models import get_model |
| from utils import get_config, sliding_window_predict |
|
|
| parser = ArgumentParser(description="Test a trained model on the NWPU-Crowd test set.") |
| |
| parser.add_argument("--model", type=str, default="vgg19_ae", help="The model to train.") |
| parser.add_argument("--input_size", type=int, default=448, help="The size of the input image.") |
| parser.add_argument("--reduction", type=int, default=8, choices=[8, 16, 32], help="The reduction factor of the model.") |
| parser.add_argument("--regression", action="store_true", help="Use blockwise regression instead of classification.") |
| parser.add_argument("--truncation", type=int, default=None, help="The truncation of the count.") |
| parser.add_argument("--anchor_points", type=str, default="average", choices=["average", "middle"], help="The representative count values of bins.") |
| parser.add_argument("--prompt_type", type=str, default="word", choices=["word", "number"], help="The prompt type for CLIP.") |
| parser.add_argument("--granularity", type=str, default="fine", choices=["fine", "dynamic", "coarse"], help="The granularity of bins.") |
| parser.add_argument("--num_vpt", type=int, default=32, help="The number of visual prompt tokens.") |
| parser.add_argument("--vpt_drop", type=float, default=0.0, help="The dropout rate for visual prompt tokens.") |
| parser.add_argument("--shallow_vpt", action="store_true", help="Use shallow visual prompt tokens.") |
| parser.add_argument("--weight_path", type=str, required=True, help="The path to the weights of the model.") |
|
|
| |
| parser.add_argument("--sliding_window", action="store_true", help="Use sliding window strategy for evaluation.") |
| parser.add_argument("--stride", type=int, default=None, help="The stride for sliding window strategy.") |
| parser.add_argument("--window_size", type=int, default=None, help="The window size for in prediction.") |
| parser.add_argument("--resize_to_multiple", action="store_true", help="Resize the image to the nearest multiple of the input size.") |
| parser.add_argument("--zero_pad_to_multiple", action="store_true", help="Zero pad the image to the nearest multiple of the input size.") |
|
|
| parser.add_argument("--device", type=str, default="cuda", help="The device to use for evaluation.") |
| parser.add_argument("--num_workers", type=int, default=4, help="The number of workers for the data loader.") |
|
|
|
|
| def main(args: ArgumentParser): |
| print("Testing a trained model on the NWPU-Crowd test set.") |
| device = torch.device(args.device) |
| _ = get_config(vars(args).copy(), mute=False) |
| if args.regression: |
| bins, anchor_points = None, None |
| else: |
| with open(os.path.join(current_dir, "configs", f"reduction_{args.reduction}.json"), "r") as f: |
| config = json.load(f)[str(args.truncation)]["nwpu"] |
| bins = config["bins"][args.granularity] |
| anchor_points = config["anchor_points"][args.granularity]["average"] if args.anchor_points == "average" else config["anchor_points"][args.granularity]["middle"] |
| bins = [(float(b[0]), float(b[1])) for b in bins] |
| anchor_points = [float(p) for p in anchor_points] |
|
|
| args.bins = bins |
| args.anchor_points = anchor_points |
|
|
| model = get_model( |
| backbone=args.model, |
| input_size=args.input_size, |
| reduction=args.reduction, |
| bins=bins, |
| anchor_points=anchor_points, |
| prompt_type=args.prompt_type, |
| num_vpt=args.num_vpt, |
| vpt_drop=args.vpt_drop, |
| deep_vpt=not args.shallow_vpt |
| ) |
| state_dict = torch.load(args.weight_path, map_location="cpu") |
| state_dict = state_dict if "best" in os.path.basename(args.weight_path) else state_dict["model_state_dict"] |
| model.load_state_dict(state_dict, strict=True) |
| model = model.to(device) |
| model.eval() |
|
|
| sliding_window = args.sliding_window |
| if args.sliding_window: |
| window_size = args.input_size |
| stride = window_size // 2 if args.stride is None else args.stride |
| if args.resize_to_multiple: |
| transforms = Resize2Multiple(base=args.input_size) |
| else: |
| transforms = None |
| else: |
| window_size, stride = None, None |
| transforms = None |
|
|
| dataset = NWPUTest(transforms=transforms, return_filename=True) |
|
|
| image_ids = [] |
| preds = [] |
|
|
| for idx in tqdm(range(len(dataset)), desc="Testing on NWPU"): |
| image, image_path = dataset[idx] |
| image = image.unsqueeze(0) |
| image = image.to(device) |
|
|
| with torch.set_grad_enabled(False): |
| if sliding_window: |
| pred_density = sliding_window_predict(model, image, window_size, stride) |
| else: |
| pred_density = model(image) |
|
|
| pred_count = pred_density.sum(dim=(1, 2, 3)).item() |
|
|
| image_ids.append(os.path.basename(image_path).split(".")[0]) |
| preds.append(pred_count) |
|
|
| result_dir = os.path.join(current_dir, "nwpu_test_results") |
| os.makedirs(result_dir, exist_ok=True) |
| weights_dir, weights_name = os.path.split(args.weight_path) |
| model_name = os.path.split(weights_dir)[-1] |
| result_path = os.path.join(result_dir, f"{model_name}_{weights_name.split('.')[0]}.txt") |
|
|
| with open(result_path, "w") as f: |
| for idx, (image_id, pred) in enumerate(zip(image_ids, preds)): |
| if idx != len(image_ids) - 1: |
| f.write(f"{image_id} {pred}\n") |
| else: |
| f.write(f"{image_id} {pred}") |
|
|
|
|
| if __name__ == "__main__": |
| args = parser.parse_args() |
| args.model = args.model.lower() |
|
|
| if args.regression: |
| args.truncation = None |
| args.anchor_points = None |
| args.bins = None |
| args.prompt_type = None |
| args.granularity = None |
|
|
| if "clip_vit" not in args.model: |
| args.num_vpt = None |
| args.vpt_drop = None |
| args.shallow_vpt = None |
| |
| if "clip" not in args.model: |
| args.prompt_type = None |
|
|
| if args.sliding_window: |
| args.window_size = args.input_size if args.window_size is None else args.window_size |
| args.stride = args.input_size if args.stride is None else args.stride |
| assert not (args.zero_pad_to_multiple and args.resize_to_multiple), "Cannot use both zero pad and resize to multiple." |
|
|
| else: |
| args.window_size = None |
| args.stride = None |
| args.zero_pad_to_multiple = False |
| args.resize_to_multiple = False |
| |
| main(args) |
|
|
| |
| |