"""
Quick training script - for testing and debugging.
Uses only the first N samples (default 100) of the dataset for fast
multi-epoch smoke tests.
"""
| | import argparse |
| | import os, sys |
| | import math |
| | BASE_DIR = os.path.dirname(os.path.abspath(__file__)) |
| | sys.path.append(BASE_DIR) |
| |
|
| | import pprint |
| | import time |
| | import torch |
| | import torch.nn.parallel |
| | from torch.cuda import amp |
| | import torch.backends.cudnn as cudnn |
| | import torch.optim |
| | import torch.utils.data |
| | import torchvision.transforms as transforms |
| | import numpy as np |
| | from tensorboardX import SummaryWriter |
| |
|
| | import lib.dataset as dataset |
| | from lib.config import cfg |
| | from lib.config import update_config |
| | from lib.core.loss import get_loss |
| | from lib.core.function import train |
| | from lib.core.function import validate |
| | from lib.core.general import fitness |
| | from lib.models import get_net |
| | from lib.utils.utils import get_optimizer |
| | from lib.utils.utils import save_checkpoint |
| | from lib.utils.utils import create_logger, select_device |
| |
|
| |
|
def parse_args():
    """Parse and return the command-line options for a quick training run."""
    ap = argparse.ArgumentParser(description='Quick train for testing')

    ap.add_argument('--config', type=str, default='yolov11',
                    help='config to use: default or yolov11')
    ap.add_argument('--samples', type=int, default=100,
                    help='number of samples to use for quick test')
    ap.add_argument('--epochs', type=int, default=10,
                    help='number of epochs for quick test')
    ap.add_argument('--batch-size', type=int, default=4,
                    help='batch size for quick test')
    ap.add_argument('--yolo-scale', type=str, default='s',
                    choices=['n', 's', 'm', 'l', 'x'],
                    help='YOLOv11 scale (only used if config=yolov11)')
    ap.add_argument('--freeze-backbone', action='store_true',
                    help='freeze YOLOv11 backbone')
    ap.add_argument('--workers', type=int, default=0,
                    help='number of data loading workers')

    return ap.parse_args()
| |
|
| |
|
class SubsetDataset(torch.utils.data.Dataset):
    """Dataset wrapper that exposes only the first ``num_samples`` items.

    The effective length is clamped to the size of the wrapped dataset,
    so requesting more samples than exist is safe.
    """

    def __init__(self, dataset, num_samples):
        self.dataset = dataset
        # Never report more items than the underlying dataset holds.
        self.num_samples = num_samples if num_samples < len(dataset) else len(dataset)

    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx):
        # Out-of-range access must raise IndexError so iteration terminates.
        if idx < self.num_samples:
            return self.dataset[idx]
        raise IndexError
| |
|
| |
|
def main():
    """Run a quick multi-epoch training loop on a small dataset subset.

    Intended as a smoke test for a configuration: trims the train/val sets
    to ``--samples`` items, trains for ``--epochs`` epochs, validates
    periodically, and saves per-epoch and best-fitness checkpoints.
    """
    args = parse_args()

    # Select the configuration; the yolov11 config carries extra
    # backbone-related knobs that the default one does not.
    if args.config == 'yolov11':
        from lib.config.yolov11 import cfg
        cfg.MODEL.YOLOV11_SCALE = args.yolo_scale
        cfg.MODEL.YOLOV11_WEIGHTS = f'weights/yolo11{args.yolo_scale}.pt'
        cfg.MODEL.FREEZE_BACKBONE = args.freeze_backbone
    else:
        from lib.config.default import _C as cfg

    # Override hyper-parameters for the quick run.
    cfg.TRAIN.BEGIN_EPOCH = 0
    cfg.TRAIN.END_EPOCH = args.epochs
    cfg.TRAIN.BATCH_SIZE_PER_GPU = args.batch_size
    cfg.WORKERS = args.workers
    cfg.PRINT_FREQ = 5

    logger, final_output_dir, tb_log_dir = create_logger(
        cfg, cfg.LOG_DIR, 'quick_train'
    )

    logger.info("="*80)
    logger.info("QUICK TRAIN MODE - Testing Configuration")
    logger.info("="*80)
    logger.info(f"Config: {args.config}")
    logger.info(f"Samples: {args.samples}")
    logger.info(f"Epochs: {args.epochs}")
    logger.info(f"Batch size: {args.batch_size}")
    if args.config == 'yolov11':
        logger.info(f"YOLOv11 scale: {args.yolo_scale}")
        logger.info(f"Freeze backbone: {args.freeze_backbone}")
    logger.info("="*80)

    writer_dict = {
        'writer': SummaryWriter(log_dir=tb_log_dir),
        'train_global_steps': 0,
        'valid_global_steps': 0,
    }

    # cuDNN behavior comes from the config so quick runs match full runs.
    cudnn.benchmark = cfg.CUDNN.BENCHMARK
    torch.backends.cudnn.deterministic = cfg.CUDNN.DETERMINISTIC
    torch.backends.cudnn.enabled = cfg.CUDNN.ENABLED

    logger.info("Building model...")
    device = select_device(logger, batch_size=cfg.TRAIN.BATCH_SIZE_PER_GPU)

    if hasattr(cfg.MODEL, 'USE_YOLOV11') and cfg.MODEL.USE_YOLOV11:
        model = get_net(
            cfg,
            yolo_scale=cfg.MODEL.YOLOV11_SCALE,
            yolo_weights_path=cfg.MODEL.YOLOV11_WEIGHTS,
            freeze_backbone=cfg.MODEL.FREEZE_BACKBONE
        ).to(device)
    else:
        model = get_net(cfg).to(device)

    logger.info("Model created successfully")

    # Debug dump of the detection head for visual inspection.
    print("++++++++++++++++++++++")
    print(model.model[model.detector_index])

    # Parameter accounting, useful when --freeze-backbone is on.
    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    logger.info(f"Total parameters: {total_params:,}")
    logger.info(f"Trainable parameters: {trainable_params:,}")
    logger.info(f"Frozen parameters: {total_params - trainable_params:,}")

    criterion = get_loss(cfg, device=device)
    optimizer = get_optimizer(cfg, model)

    # Cosine LR schedule decaying from 1.0 to cfg.TRAIN.LRF over all epochs.
    lf = lambda x: ((1 + math.cos(x * math.pi / cfg.TRAIN.END_EPOCH)) / 2) * \
                   (1 - cfg.TRAIN.LRF) + cfg.TRAIN.LRF
    lr_scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lf)

    logger.info("Loading dataset...")
    normalize = transforms.Normalize(
        mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
    )

    # Resolve the dataset class by name; getattr avoids eval() on config text.
    dataset_cls = getattr(dataset, cfg.DATASET.DATASET)

    train_dataset = dataset_cls(
        cfg=cfg,
        is_train=True,
        inputsize=cfg.MODEL.IMAGE_SIZE,
        transform=transforms.Compose([
            transforms.ToTensor(),
            normalize,
        ])
    )

    # Trim to the requested number of samples for the quick run.
    train_dataset = SubsetDataset(train_dataset, args.samples)
    logger.info(f"Using {len(train_dataset)} training samples")

    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=cfg.TRAIN.BATCH_SIZE_PER_GPU,
        shuffle=True,
        num_workers=cfg.WORKERS,
        pin_memory=cfg.PIN_MEMORY,
        collate_fn=dataset.AutoDriveDataset.collate_fn
    )

    valid_dataset = dataset_cls(
        cfg=cfg,
        is_train=False,
        inputsize=cfg.MODEL.IMAGE_SIZE,
        transform=transforms.Compose([
            transforms.ToTensor(),
            normalize,
        ])
    )
    # Validation uses half as many samples as training.
    valid_dataset = SubsetDataset(valid_dataset, args.samples // 2)
    logger.info(f"Using {len(valid_dataset)} validation samples")

    valid_loader = torch.utils.data.DataLoader(
        valid_dataset,
        batch_size=cfg.TEST.BATCH_SIZE_PER_GPU,
        shuffle=False,
        num_workers=cfg.WORKERS,
        pin_memory=cfg.PIN_MEMORY,
        collate_fn=dataset.AutoDriveDataset.collate_fn
    )

    # Mixed-precision scaler; disabled automatically on CPU.
    scaler = amp.GradScaler(enabled=device.type != 'cpu')

    logger.info("Starting training...")
    logger.info("="*80)

    best_fitness = 0.0  # kept as a plain float so final logging never fails
    num_batch = len(train_loader)
    # NOTE(review): the 1000-iteration floor likely exceeds the entire quick
    # run (samples/batch * epochs), so training may stay in warmup throughout.
    num_warmup = max(round(cfg.TRAIN.WARMUP_EPOCHS * num_batch), 1000)

    for epoch in range(cfg.TRAIN.BEGIN_EPOCH, cfg.TRAIN.END_EPOCH):
        logger.info(f"\n{'='*80}")
        logger.info(f"Epoch {epoch}/{cfg.TRAIN.END_EPOCH-1}")
        logger.info(f"{'='*80}")

        train(
            cfg, train_loader, model, criterion, optimizer,
            scaler, epoch, num_batch, num_warmup,
            writer_dict, logger, device
        )

        lr_scheduler.step()

        # Validate periodically and always on the final epoch.
        if (epoch % cfg.TRAIN.VAL_FREQ == 0 or epoch == cfg.TRAIN.END_EPOCH - 1):
            logger.info("\nValidating...")
            da_segment_results, ll_segment_results, detect_results, total_loss, maps, times = validate(
                epoch, cfg, valid_loader, valid_dataset, model, criterion,
                final_output_dir, tb_log_dir, writer_dict, logger, device
            )

            # Convert to float so comparisons and logging work even when the
            # best fitness is never updated (fix: .item() on a float crashed).
            fi = float(fitness(np.array(detect_results).reshape(1, -1)))
            logger.info(f"Fitness: {fi:.4f}")

            if fi > best_fitness:
                best_fitness = fi
                logger.info(f"New best fitness: {best_fitness:.4f}")
                save_checkpoint(
                    epoch=epoch + 1,
                    name=cfg.MODEL.NAME,  # fix: was debug placeholder '111'
                    model=model,
                    optimizer=optimizer,
                    output_dir=final_output_dir,
                    filename='checkpoint_best.pth',
                    is_best=True
                )

        # Per-epoch checkpoint regardless of validation outcome.
        save_checkpoint(
            epoch=epoch,
            name=cfg.MODEL.NAME,
            model=model,
            optimizer=optimizer,
            output_dir=final_output_dir,
            filename=f'epoch-{epoch}.pth'
        )

    logger.info("\n" + "="*80)
    logger.info("Training completed!")
    logger.info(f"Best fitness: {best_fitness:.4f}")
    logger.info(f"Results saved to: {final_output_dir}")
    logger.info("="*80)

    writer_dict['writer'].close()
| |
|
| |
|
# Script entry point: only run training when executed directly, not on import.
if __name__ == '__main__':
    main()
| |
|