| import os |
| import sys |
| import torch |
| import wandb |
| import matplotlib.pyplot as plt |
| import pytorch_lightning as pl |
| from torch.optim import AdamW |
| from torchmetrics.functional import mean_squared_error |
| from torchdyn.core import NeuralODE |
| from .networks.utils import flow_model_torch_wrapper |
| from .utils import wasserstein, plot_lidar |
| from .ema import EMA |
|
|
class BranchFlowNetTrainBase(pl.LightningModule):
    """Base Lightning module for conditional-flow-matching training with one
    flow network per branch.

    Each network in ``flow_nets`` regresses the conditional vector field
    sampled from ``flow_matcher`` along its own branch; the training loss is
    the sum of the per-branch MSE terms.
    """

    def __init__(
        self,
        flow_matcher,
        flow_nets,
        skipped_time_points=None,
        ot_sampler=None,
        args=None,
    ):
        """Store the training components and read optimizer settings from ``args``.

        Args:
            flow_matcher: object providing ``sample_location_and_conditional_flow``;
                its ``alpha`` and ``geopath_net_output`` attributes are read for
                logging.
            flow_nets: sequence of vector-field networks, one per branch.
            skipped_time_points: optional indices of held-out time points; the
                flow is trained to bridge across the first one.
            ot_sampler: optional sampler with ``sample_plan`` used to couple
                (x0, x1) minibatches via optimal transport.
            args: namespace exposing ``flow_optimizer``, ``flow_lr``,
                ``flow_weight_decay``, ``whiten`` and ``working_dir``.
        """
        super().__init__()
        self.args = args

        self.flow_matcher = flow_matcher
        self.flow_nets = flow_nets
        self.ot_sampler = ot_sampler
        self.skipped_time_points = skipped_time_points

        self.optimizer_name = args.flow_optimizer
        self.lr = args.flow_lr
        self.weight_decay = args.flow_weight_decay
        self.whiten = args.whiten
        self.working_dir = args.working_dir

        self.branches = len(flow_nets)

    def forward(self, t, xt, branch_idx):
        """Evaluate branch ``branch_idx``'s vector field at time ``t`` and state ``xt``."""
        return self.flow_nets[branch_idx](t, xt)

    def _compute_loss(self, main_batch):
        """Sum the flow-matching MSE over all branches for one batch.

        ``main_batch`` maps "x0" (and "x1" or "x1_{i}" per branch) to
        (samples, weights) tuples; only the samples (index 0) are consumed
        here, so the unused weight entries are no longer unpacked.
        """
        x0s = [main_batch["x0"][0]]

        x1s_list = []
        if self.branches > 1:
            # Multi-branch batches name their endpoint samples x1_1 ... x1_B.
            for i in range(self.branches):
                x1s_list.append([main_batch[f"x1_{i + 1}"][0]])
        else:
            x1s_list.append([main_batch["x1"][0]])

        assert len(x1s_list) == self.branches, "Mismatch between x1s_list and expected branches"

        loss = 0
        for branch_idx in range(self.branches):
            ts, xts, uts = self._process_flow(x0s, x1s_list[branch_idx], branch_idx)

            t = torch.cat(ts)
            xt = torch.cat(xts)
            ut = torch.cat(uts)
            vt = self(t[:, None], xt, branch_idx)

            loss += mean_squared_error(vt, ut)

        return loss

    def _process_flow(self, x0s, x1s, branch_idx):
        """Sample (t, x_t, u_t) regression targets for each (x0, x1) segment.

        Walks forward through ``self.timesteps`` (populated by the train/val
        steps); when the next time point is held out, the segment spans two
        intervals instead of one.
        """
        ts, xts, uts = [], [], []
        t_start = self.timesteps[0]

        for i, (x0, x1) in enumerate(zip(x0s, x1s)):
            x0, x1 = torch.squeeze(x0), torch.squeeze(x1)

            # Optionally re-pair the minibatch endpoints with an OT coupling.
            if self.ot_sampler is not None:
                x0, x1 = self.ot_sampler.sample_plan(
                    x0,
                    x1,
                    replace=True,
                )
            if self.skipped_time_points and i + 1 >= self.skipped_time_points[0]:
                # Bridge across the held-out time point.
                t_start_next = self.timesteps[i + 2]
            else:
                t_start_next = self.timesteps[i + 1]

            t, xt, ut = self.flow_matcher.sample_location_and_conditional_flow(
                x0, x1, t_start, t_start_next, branch_idx
            )

            ts.append(t)
            xts.append(xt)
            uts.append(ut)
            t_start = t_start_next
        return ts, xts, uts

    @staticmethod
    def _unpack_batch(batch, key):
        """Pull the sample dict named ``key`` out of Lightning's (possibly
        wrapped) batch; previously duplicated in train/val steps."""
        if isinstance(batch, (list, tuple)):
            batch = batch[0]
        if isinstance(batch, dict) and key in batch:
            main_batch = batch[key]
            if isinstance(main_batch, tuple):
                main_batch = main_batch[0]
        else:
            main_batch = batch.get(key, batch)
        return main_batch

    def training_step(self, batch, batch_idx):
        """One optimization step; also logs the geopath magnitude when active."""
        main_batch = self._unpack_batch(batch, "train_samples")

        # Uniform time grid over the observed snapshots for this batch.
        # (Leftover debug prints of the batch length were removed.)
        self.timesteps = torch.linspace(0.0, 1.0, len(main_batch["x0"])).tolist()
        loss = self._compute_loss(main_batch)
        if self.flow_matcher.alpha != 0:
            self.log(
                "FlowNet/mean_geopath_cfm",
                (self.flow_matcher.geopath_net_output.abs().mean()),
                on_step=False,
                on_epoch=True,
                prog_bar=True,
            )

        self.log(
            "FlowNet/train_loss_cfm",
            loss,
            on_step=False,
            on_epoch=True,
            prog_bar=True,
            logger=True,
        )

        return loss

    def validation_step(self, batch, batch_idx):
        """Validation mirror of ``training_step`` without geopath logging."""
        main_batch = self._unpack_batch(batch, "val_samples")

        self.timesteps = torch.linspace(0.0, 1.0, len(main_batch["x0"])).tolist()
        val_loss = self._compute_loss(main_batch)
        self.log(
            "FlowNet/val_loss_cfm",
            val_loss,
            on_step=False,
            on_epoch=True,
            prog_bar=True,
            logger=True,
        )
        return val_loss

    def optimizer_step(self, *args, **kwargs):
        """Run the usual optimizer step, then refresh any EMA shadow weights."""
        super().optimizer_step(*args, **kwargs)
        for net in self.flow_nets:
            if isinstance(net, EMA):
                net.update_ema()

    def configure_optimizers(self):
        """Build the optimizer named by ``args.flow_optimizer``.

        Raises:
            ValueError: if the name is neither "adamw" nor "adam" (the
                original fell through and raised ``UnboundLocalError``).
        """
        if self.optimizer_name == "adamw":
            return AdamW(
                self.parameters(),
                lr=self.lr,
                weight_decay=self.weight_decay,
            )
        if self.optimizer_name == "adam":
            return torch.optim.Adam(
                self.parameters(),
                lr=self.lr,
            )
        raise ValueError(f"Unsupported flow optimizer: {self.optimizer_name!r}")
|
|
|
|
class FlowNetTrainTrajectory(BranchFlowNetTrainBase):
    """Evaluates a trained flow by integrating across a held-out time point."""

    def test_step(self, batch, batch_idx):
        """Integrate from the snapshot before the skipped time point and log
        the 1-Wasserstein distance (EMD) to the held-out snapshot.

        Assumes ``self.timesteps`` was populated by a prior train/val step,
        and that ``batch`` is indexable by time-point position — TODO confirm
        against the test dataloader.
        """
        node = NeuralODE(
            flow_model_torch_wrapper(self.flow_nets),
            solver="euler",
            sensitivity="adjoint",
            atol=1e-5,
            rtol=1e-5,
        )

        t_exclude = self.skipped_time_points[0] if self.skipped_time_points else None
        if t_exclude is None:
            # Nothing held out: nothing to evaluate.
            return

        traj = node.trajectory(
            batch[t_exclude - 1],
            t_span=torch.linspace(
                self.timesteps[t_exclude - 1], self.timesteps[t_exclude], 101
            ),
        )
        X_mid_pred = traj[-1]
        # NOTE(review): the original also integrated on to
        # timesteps[t_exclude + 1] but never used that trajectory; the dead
        # (and expensive) recomputation was dropped. The unused local
        # ``data_type`` was removed as well.

        EMD = wasserstein(X_mid_pred, batch[t_exclude], p=1)
        self.final_EMD = EMD

        self.log("test_EMD", EMD, on_step=False, on_epoch=True, prog_bar=True)
|
|
class FlowNetTrainCell(BranchFlowNetTrainBase):
    """Test-time evaluation for 2D cell data: integrates every branch from x0
    and saves a combined figure plus one figure per branch."""

    @staticmethod
    def _draw_trajs_2d(ax, dataset_2d, trajs, alpha, marker_size):
        """Scatter the dataset, then overlay trajectory lines with green t=0 /
        red t=1 endpoint markers for every trajectory in ``trajs``."""
        ax.scatter(dataset_2d[:, 0], dataset_2d[:, 1], c="gray", s=1, alpha=0.5, label="Dataset", zorder=1)
        for traj in trajs:
            traj_2d = traj[..., :2]
            for i in range(traj_2d.shape[0]):
                ax.plot(traj_2d[i, :, 0], traj_2d[i, :, 1], alpha=alpha, zorder=2)
                ax.scatter(traj_2d[i, 0, 0], traj_2d[i, 0, 1], c='green', s=marker_size, label="t=0" if i == 0 else "", zorder=3)
                ax.scatter(traj_2d[i, -1, 0], traj_2d[i, -1, 1], c='red', s=marker_size, label="t=1" if i == 0 else "", zorder=3)

    @staticmethod
    def _finalize_axes(ax, title):
        """Apply shared title/labels/legend styling to a trajectory figure."""
        ax.set_title(title)
        ax.set_xlabel("x")
        ax.set_ylabel("y")
        plt.axis("equal")
        _, labels = ax.get_legend_handles_labels()
        if labels:
            ax.legend()

    def test_step(self, batch, batch_idx):
        """Integrate each branch's NeuralODE from x0 and plot the trajectories
        against the reference dataset points."""
        x0 = batch[0]["test_samples"][0]["x0"][0]
        dataset_points = batch[0]["test_samples"][0]["dataset"][0]
        t_span = torch.linspace(0, 1, 101)

        all_trajs = []
        for flow_net in self.flow_nets:
            node = NeuralODE(
                flow_model_torch_wrapper(flow_net),
                solver="euler",
                sensitivity="adjoint",
            )

            with torch.no_grad():
                traj = node.trajectory(x0, t_span).cpu()

            if self.whiten:
                traj_shape = traj.shape
                flat = traj.reshape(-1, traj.shape[-1])
                traj = self.trainer.datamodule.scaler.inverse_transform(
                    flat.detach().numpy()
                ).reshape(traj_shape)

            # as_tensor is a no-op for tensors and converts the numpy result
            # of inverse_transform; torch.tensor() would warn and copy.
            traj = torch.as_tensor(traj)
            # presumably (time, batch, dim) -> (batch, time, dim) — TODO confirm
            traj = torch.transpose(traj, 0, 1)
            all_trajs.append(traj)

        # Un-whiten the reference point cloud ONCE. The original ran this
        # inside the per-branch loop, applying inverse_transform repeatedly
        # to the same array whenever there was more than one branch.
        if self.whiten:
            dataset_points = self.trainer.datamodule.scaler.inverse_transform(
                dataset_points.cpu().detach().numpy()
            )
        # inverse_transform yields a numpy array; the original called
        # .cpu().numpy() unconditionally and crashed on that path.
        if isinstance(dataset_points, torch.Tensor):
            dataset_points = dataset_points.cpu().numpy()
        dataset_2d = dataset_points[:, :2]

        run_name = self.args.run_name if hasattr(self.args, 'run_name') and self.args.run_name else self.args.data_name
        results_dir = os.path.join(self.args.working_dir, 'results', run_name)
        save_path = os.path.join(results_dir, 'figures')
        os.makedirs(save_path, exist_ok=True)

        # Combined figure with every branch.
        fig, ax = plt.subplots(figsize=(6, 5))
        self._draw_trajs_2d(ax, dataset_2d, all_trajs, alpha=0.8, marker_size=10)
        self._finalize_axes(ax, "All Branch Trajectories (2D) with Dataset")
        plt.savefig(f'{save_path}/{self.args.data_name}_all_branches.png', dpi=300)
        plt.close()

        # One figure per branch.
        for i, traj in enumerate(all_trajs):
            fig, ax = plt.subplots(figsize=(6, 5))
            self._draw_trajs_2d(ax, dataset_2d, [traj], alpha=0.9, marker_size=12)
            self._finalize_axes(ax, f"Branch {i + 1} Trajectories (2D) with Dataset")
            plt.savefig(f'{save_path}/{self.args.data_name}_branch_{i + 1}.png', dpi=300)
            plt.close()
|
|
class FlowNetTrainLidar(BranchFlowNetTrainBase):
    """Test-time evaluation for 3D lidar data: integrates every branch from x0
    and renders the trajectories against the lidar point cloud."""

    def test_step(self, batch, batch_idx):
        """Integrate each branch's NeuralODE and save one combined lidar
        figure plus one figure per branch."""
        # Batch may arrive as a dict of named dataloaders or as a plain list.
        # NOTE(review): the original also unpacked a "metric_samples" batch
        # but never used it, so that dead local was dropped.
        if isinstance(batch, dict):
            main_batch = batch["test_samples"][0]
        else:
            main_batch = batch[0][0]

        x0 = main_batch["x0"][0]
        cloud_points = main_batch["dataset"][0]
        t_span = torch.linspace(0, 1, 101)

        all_trajs = []
        for flow_net in self.flow_nets:
            node = NeuralODE(
                flow_model_torch_wrapper(flow_net),
                solver="euler",
                sensitivity="adjoint",
            )

            with torch.no_grad():
                traj = node.trajectory(x0, t_span).cpu()

            if self.whiten:
                traj_shape = traj.shape
                # Lidar points are 3-dimensional, hence the hard-coded width.
                flat = traj.reshape(-1, 3)
                traj = self.trainer.datamodule.scaler.inverse_transform(
                    flat.detach().numpy()
                ).reshape(traj_shape)

            # as_tensor is a no-op for tensors and converts the numpy result
            # of inverse_transform; torch.tensor() would warn and copy.
            traj = torch.as_tensor(traj)
            # presumably (time, batch, 3) -> (batch, time, 3) — TODO confirm
            traj = torch.transpose(traj, 0, 1)
            all_trajs.append(traj)

        # Un-whiten the reference cloud once, outside the branch loop.
        if self.whiten:
            cloud_points = torch.tensor(
                self.trainer.datamodule.scaler.inverse_transform(
                    cloud_points.cpu().detach().numpy()
                )
            )

        run_name = self.args.run_name if hasattr(self.args, 'run_name') and self.args.run_name else self.args.data_name
        results_dir = os.path.join(self.args.working_dir, 'results', run_name)
        lidar_fig_dir = os.path.join(results_dir, 'figures')
        os.makedirs(lidar_fig_dir, exist_ok=True)

        # Combined figure with every branch.
        fig = plt.figure(figsize=(6, 5))
        ax = fig.add_subplot(111, projection="3d", computed_zorder=False)
        ax.view_init(elev=30, azim=-115, roll=0)
        for i, traj in enumerate(all_trajs):
            plot_lidar(ax, cloud_points, xs=traj, branch_idx=i)
        plt.savefig(os.path.join(lidar_fig_dir, 'lidar_all_branches.png'), dpi=300)
        plt.close()

        # One figure per branch.
        for i, traj in enumerate(all_trajs):
            fig = plt.figure(figsize=(6, 5))
            ax = fig.add_subplot(111, projection="3d", computed_zorder=False)
            ax.view_init(elev=30, azim=-115, roll=0)
            plot_lidar(ax, cloud_points, xs=traj, branch_idx=i)
            plt.savefig(os.path.join(lidar_fig_dir, f'lidar_branch_{i + 1}.png'), dpi=300)
            plt.close()