| import torch |
|
|
| import CrosswalkDataset as Dataset |
| import ClassifierModel as Model |
| import Utilities as Utils |
|
|
|
|
def train_model_v0(model_to_train, dataset, epoch_number=25, loss_func=Utils.BasicClassificationLoss,
                   batch_size=16, save=False, lr=0.001):
    """Train ``model_to_train`` on ``dataset`` with Adam at a constant learning rate.

    Args:
        model_to_train: torch module to optimise; it is trained in place.
        dataset: dataset passed to a shuffling ``DataLoader``; each batch is
            unpacked as ``(images, gt_labels)`` (assumed layout — confirm
            against the dataset class).
        epoch_number: number of full passes over ``dataset``.
        loss_func: zero-argument callable (e.g. a loss class) that returns the
            loss function, applied as ``loss_function(predictions, gt_labels)``.
        batch_size: mini-batch size for the ``DataLoader``.
        save: when True, persist the model and optimiser via
            ``Utils.save_model`` after training finishes.
        lr: Adam learning rate (default matches the previous hard-coded value).

    Returns:
        The same ``model_to_train`` object, after training.
    """
    optimiser = torch.optim.Adam(model_to_train.parameters(), lr=lr)
    dataloader = torch.utils.data.DataLoader(dataset, shuffle=True, batch_size=batch_size)
    loss_function = loss_func()

    for epoch in range(epoch_number):
        model_to_train.train()
        running_loss = 0.0

        for images, gt_labels in dataloader:
            optimiser.zero_grad()
            predictions = model_to_train(images)
            batch_loss = loss_function(predictions, gt_labels)
            batch_loss.backward()
            optimiser.step()
            # .item() yields a plain float; summing the tensor itself would keep
            # each batch's autograd graph alive until the end of the epoch.
            running_loss += batch_loss.item()

        num_batches = len(dataloader)
        # Guard against an empty dataset instead of dividing by zero.
        mean_loss = running_loss / num_batches if num_batches else 0.0
        # len(dataset) is the true sample count; num_batches * batch_size
        # overcounts when the final batch is partial.
        print(f"Epoch [{epoch + 1} of {epoch_number}] finished, with loss {mean_loss} in "
              f"len {len(dataset)}")

    if save:
        Utils.save_model(model_to_train, optimiser)
    return model_to_train
|
|
|
|
| |
def train_model_v1(model_to_train, dataset, epoch_number=10, loss_func=Utils.BasicClassificationLoss,
                   batch_size=16, save=False, lr=0.001, gamma=0.95):
    """Train ``model_to_train`` on ``dataset`` with Adam and exponential LR decay.

    The learning rate is multiplied by ``gamma`` once at the end of every epoch
    (``ExponentialLR`` stepped per epoch).

    Args:
        model_to_train: torch module to optimise; it is trained in place.
        dataset: dataset passed to a shuffling ``DataLoader``; each batch is
            unpacked as ``(images, gt_labels)`` (assumed layout — confirm
            against the dataset class).
        epoch_number: number of full passes over ``dataset``.
        loss_func: zero-argument callable (e.g. a loss class) that returns the
            loss function, applied as ``loss_function(predictions, gt_labels)``.
        batch_size: mini-batch size for the ``DataLoader``.
        save: when True, persist the model and optimiser via
            ``Utils.save_model`` after training finishes.
        lr: initial Adam learning rate (default matches the previous value).
        gamma: per-epoch multiplicative LR decay (default matches the previous value).

    Returns:
        The same ``model_to_train`` object, after training.
    """
    optimiser = torch.optim.Adam(model_to_train.parameters(), lr=lr)
    scheduler = torch.optim.lr_scheduler.ExponentialLR(optimiser, gamma=gamma)
    dataloader = torch.utils.data.DataLoader(dataset, shuffle=True, batch_size=batch_size)
    loss_function = loss_func()

    for epoch in range(epoch_number):
        model_to_train.train()
        running_loss = 0.0

        for images, gt_labels in dataloader:
            optimiser.zero_grad()
            predictions = model_to_train(images)
            batch_loss = loss_function(predictions, gt_labels)
            batch_loss.backward()
            optimiser.step()
            # .item() yields a plain float; summing the tensor itself would keep
            # each batch's autograd graph alive until the end of the epoch.
            running_loss += batch_loss.item()

        # Decay the learning rate once per epoch, after all optimiser steps.
        scheduler.step()

        num_batches = len(dataloader)
        # Guard against an empty dataset instead of dividing by zero.
        mean_loss = running_loss / num_batches if num_batches else 0.0
        # len(dataset) is the true sample count; num_batches * batch_size
        # overcounts when the final batch is partial.
        print(f"Epoch [{epoch + 1} of {epoch_number}] finished, with loss {mean_loss} in "
              f"len {len(dataset)}")

    if save:
        Utils.save_model(model_to_train, optimiser)
    return model_to_train
|
|
|
|
if __name__ == "__main__":
    # Train only when run as a script, so that importing this module to reuse
    # train_model_v0 / train_model_v1 does not kick off a full training run.
    model = Model.BasicClassificationModel(image_size=416)

    crosswalk_dataset = Dataset.CrosswalkDataset("Crosswalk.v7-crosswalk-t3.tensorflow/train/_annotations.csv",
                                                 "Crosswalk.v7-crosswalk-t3.tensorflow/train")
    model = train_model_v1(model, crosswalk_dataset, save=True)
|
|