| import sys |
| import os |
| import torch |
| import pytorch_lightning as pl |
| from .ema import EMA |
| import itertools |
| from .utils import plot_lidar |
| import matplotlib.pyplot as plt |
|
|
class BranchInterpolantTrain(pl.LightningModule):
    """Lightning module that trains the geodesic-path ("geopath") networks of a
    branched flow matcher by minimising the mean squared velocity along the
    interpolated paths between source and per-branch target marginals.
    """

    def __init__(
        self,
        flow_matcher,
        args,
        skipped_time_points: list = None,
        ot_sampler=None,
        state_cost=None,
        data_manifold_metric=None,
    ):
        """Store training configuration and the networks to optimise.

        Args:
            flow_matcher: Branched flow matcher; its ``geopath_nets`` (one per
                branch) are the modules this LightningModule optimises.
            args: Namespace of hyperparameters (``metric_clusters``,
                ``geopath_optimizer``, ``geopath_lr``, ``geopath_weight_decay``,
                ``whiten``, ``manifold``, ...).
            skipped_time_points: Indices of marginals to skip when pairing
                consecutive time points; ``None`` means skip nothing.
            ot_sampler: Optional OT coupling sampler used to pair x0/x1 batches.
            state_cost: Optional state-dependent cost added to the Euclidean
                velocity when no manifold metric is used.
            data_manifold_metric: Optional metric object providing
                ``calculate_velocity`` for manifold-aware training.
        """
        super().__init__()
        self.save_hyperparameters()
        # Fix: the original assigned `self.args = args` twice (redundant);
        # keep a single assignment.
        self.args = args

        self.flow_matcher = flow_matcher

        # One geopath network per branch.
        self.geopath_nets = flow_matcher.geopath_nets
        self.branches = len(self.geopath_nets)
        self.metric_clusters = args.metric_clusters

        self.ot_sampler = ot_sampler
        self.skipped_time_points = skipped_time_points if skipped_time_points else []
        self.optimizer_name = args.geopath_optimizer
        self.lr = args.geopath_lr
        self.weight_decay = args.geopath_weight_decay
        # Validation / reference-loss batches are tiled this many times along
        # the batch dimension (see _process_flow).
        self.multiply_validation = 4

        # Reference loss measured at train start; later losses are divided by it.
        self.first_loss = None
        # Per-marginal time grid on [0, 1]; set from the batch in each step.
        self.timesteps = None
        # Flag telling _process_flow to reuse pre-sampled self.t_val times.
        self.computing_reference_loss = False

        self.state_cost = state_cost
        self.data_manifold_metric = data_manifold_metric
        self.whiten = args.whiten
|
| def forward(self, x0, x1, t, branch_idx): |
| |
| return self.geopath_nets[branch_idx](x0, x1, t) |
|
|
| def on_train_start(self): |
| self.first_loss = self.compute_initial_loss() |
| print("first loss") |
| print(self.first_loss) |
|
|
| |
    def compute_initial_loss(self):
        """Estimate a reference loss over the training set before optimisation.

        Puts the geopath nets in eval mode, temporarily forces the flow
        matcher's ``alpha`` to 0, averages the per-batch loss over the training
        dataloader, then restores both. Also pre-samples ``self.t_val`` used by
        ``_process_flow`` while ``computing_reference_loss`` is set.

        Returns:
            float: mean loss over finite batches, or 1.0 if no batch was valid.
        """
        # Eval mode for the reference pass (restored at the end).
        for net in self.geopath_nets:
            net.train(mode=False)

        total_loss = 0
        total_count = 0
        # Sample one random time vector per (non-skipped) marginal pair with
        # grad enabled so the tensors carry requires_grad=True.
        with torch.enable_grad():
            self.t_val = []
            for i in range(
                self.trainer.datamodule.num_timesteps - len(self.skipped_time_points)
            ):
                self.t_val.append(
                    torch.rand(
                        self.trainer.datamodule.batch_size * self.multiply_validation,
                        requires_grad=True,
                    )
                )
        # Signals _process_flow to reuse self.t_val and tile batches.
        self.computing_reference_loss = True
        with torch.no_grad():
            # Disable the learned path component while measuring the baseline.
            old_alpha = self.flow_matcher.alpha
            self.flow_matcher.alpha = 0
            for batch in self.trainer.datamodule.train_dataloader():
                # Unwrap (batch,) tuples some dataloaders produce.
                if isinstance(batch, (list, tuple)):
                    batch = batch[0]
                if isinstance(batch, dict) and "train_samples" in batch:
                    main_batch_init = batch["train_samples"]
                    metric_batch_init = batch["metric_samples"]
                    if isinstance(main_batch_init, tuple):
                        main_batch_init = main_batch_init[0]
                    if isinstance(metric_batch_init, tuple):
                        metric_batch_init = metric_batch_init[0]
                else:
                    main_batch_init = batch
                    metric_batch_init = []

                # Uniform time grid on [0, 1], one point per marginal in x0.
                self.timesteps = torch.linspace(
                    0.0, 1.0, len(main_batch_init["x0"])
                ).tolist()

                loss = self._compute_loss(
                    main_batch_init,
                    metric_batch_init,
                )
                print("initial loss")
                print(loss)

                # Skip NaN/inf batches so they do not poison the average.
                if not (torch.isnan(loss) or torch.isinf(loss)):
                    total_loss += loss.item()
                    total_count += 1
            # NOTE(review): alpha is restored only on the success path; an
            # exception above would leave it at 0 — consider try/finally.
            self.flow_matcher.alpha = old_alpha

        self.computing_reference_loss = False

        # Back to training mode for the real optimisation.
        for net in self.geopath_nets:
            net.train(mode=True)
        return total_loss / total_count if total_count > 0 else 1.0
|
|
    def _compute_loss(self, main_batch, metric_samples_batch=None):
        """Mean squared path velocity across all branches for one batch.

        Args:
            main_batch: dict holding ``"x0"`` and either per-branch keys
                ``"x1_1"``..``"x1_<branches>"`` or a single ``"x1"``; each value
                is a (samples, weights) pair.
            metric_samples_batch: list of per-cluster metric sample tensors;
                only consumed when ``args.manifold`` is set.

        Returns:
            Scalar tensor: mean of squared velocities over all branch segments
            (also logged as "BranchPathNet/mean_velocity_geopath").
        """
        # Source marginal: (samples, weights).
        x0s = [main_batch["x0"][0]]
        w0s = [main_batch["x0"][1]]  # NOTE(review): weights are collected but unused below

        x1s_list = []
        w1s_list = []
        # One target marginal per branch, or a single "x1" when unbranched.
        if self.branches > 1:
            for i in range(self.branches):
                x1s_list.append([main_batch[f"x1_{i+1}"][0]])
                w1s_list.append([main_batch[f"x1_{i+1}"][1]])
        else:
            x1s_list.append([main_batch["x1"][0]])
            w1s_list.append([main_batch["x1"][1]])

        if self.args.manifold:
            # Choose the (start, end) metric-sample pair each branch uses for
            # the data-manifold metric. Layouts are hard-coded per supported
            # cluster count; cluster 0 is always the shared source cluster.
            if self.metric_clusters == 7:
                branch_sample_pairs = [
                    (metric_samples_batch[0], metric_samples_batch[1]),
                    (metric_samples_batch[0], metric_samples_batch[2]),
                    (metric_samples_batch[0], metric_samples_batch[3]),
                    (metric_samples_batch[0], metric_samples_batch[4]),
                    (metric_samples_batch[0], metric_samples_batch[5]),
                    (metric_samples_batch[0], metric_samples_batch[6]),
                ]
            elif self.metric_clusters == 4:
                branch_sample_pairs = [
                    (metric_samples_batch[0], metric_samples_batch[1]),
                    (metric_samples_batch[0], metric_samples_batch[2]),
                    (metric_samples_batch[0], metric_samples_batch[3]),
                ]
            elif self.metric_clusters == 3:
                branch_sample_pairs = [
                    (metric_samples_batch[0], metric_samples_batch[1]),
                    (metric_samples_batch[0], metric_samples_batch[2]),
                ]
            elif self.metric_clusters == 2 and self.branches == 2:
                branch_sample_pairs = [
                    (metric_samples_batch[0], metric_samples_batch[0]),
                    (metric_samples_batch[0], metric_samples_batch[1]),
                ]
            elif self.metric_clusters == 2:
                # Every branch shares the same (source, target) pair.
                branch_sample_pairs = [
                    (metric_samples_batch[0], metric_samples_batch[1])
                ] * self.branches
            else:
                branch_sample_pairs = [
                    (metric_samples_batch[0], metric_samples_batch[1]),
                ]
            # NOTE(review): if self.branches exceeds the number of pairs built
            # above, indexing branch_sample_pairs below raises IndexError.
            # NOTE(review): the string literal below is dead code left from an
            # earlier version; consider deleting it.
            """samples0, samples1, samples2 = (
                metric_samples_batch[0],
                metric_samples_batch[1],
                metric_samples_batch[2]
            )"""

        assert len(x1s_list) == self.branches, "Mismatch between x1s_list and expected branches"

        loss = 0  # NOTE(review): unused; velocity_loss below is what is returned
        velocities = []
        for branch_idx in range(self.branches):
            # Sample (t, x_t, u_t) for every time segment of this branch.
            ts, xts, uts = self._process_flow(x0s, x1s_list[branch_idx], branch_idx)

            for i in range(len(ts)):
                if self.args.manifold:
                    start_samples, end_samples = branch_sample_pairs[branch_idx]
                    # Metric is evaluated against the concatenated cluster samples.
                    samples = torch.cat([start_samples, end_samples], dim=0)

                    vel, _, _ = self.data_manifold_metric.calculate_velocity(
                        xts[i], uts[i], samples, i
                    )
                else:
                    # Euclidean speed plus a state-dependent cost term.
                    vel = torch.sqrt((uts[i]**2).sum(dim =-1) + self.state_cost(xts[i]))

                velocities.append(vel)

        velocity_loss = torch.mean(torch.cat(velocities) ** 2)

        self.log(
            "BranchPathNet/mean_velocity_geopath",
            velocity_loss,
            on_step=False,
            on_epoch=True,
            prog_bar=True,
        )

        return velocity_loss
|
|
    def _process_flow(self, x0s, x1s, branch_idx):
        """Sample times, locations, and conditional flows along one branch.

        For each consecutive (x0, x1) marginal pair, optionally tiles the batch
        during validation/reference passes, optionally resamples the coupling
        with the OT sampler, and queries the flow matcher for
        (t, x_t, u_t) on the segment [t_start, t_start_next].

        Args:
            x0s: list of source sample tensors (currently length 1).
            x1s: list of target sample tensors for this branch.
            branch_idx: index of the branch/geopath net to use.

        Returns:
            Tuple of lists ``(ts, xts, uts)``, one entry per segment.
        """
        ts, xts, uts = [], [], []
        t_start = self.timesteps[0]
        i_start = 0  # NOTE(review): unused

        for i, (x0, x1) in enumerate(zip(x0s, x1s)):
            x0, x1 = torch.squeeze(x0), torch.squeeze(x1)
            if self.trainer.validating or self.computing_reference_loss:
                # Tile the batch dimension multiply_validation times; trailing
                # dims repeat once each.
                repeat_tuple = (self.multiply_validation, 1) + (1,) * (
                    len(x0.shape) - 2
                )
                x0 = x0.repeat(repeat_tuple)
                x1 = x1.repeat(repeat_tuple)

            # Re-pair x0/x1 according to the OT plan, sampling with replacement.
            if self.ot_sampler is not None:
                x0, x1 = self.ot_sampler.sample_plan(
                    x0,
                    x1,
                    replace=True,
                )
            # Jump over a skipped marginal: the segment ends two timesteps ahead.
            if self.skipped_time_points and i + 1 >= self.skipped_time_points[0]:
                t_start_next = self.timesteps[i + 2]
            else:
                t_start_next = self.timesteps[i + 1]

            # During validation/reference passes reuse the pre-sampled times so
            # runs are comparable; otherwise the flow matcher samples t itself.
            t = None
            if self.trainer.validating or self.computing_reference_loss:
                t = self.t_val[i]

            t, xt, ut = self.flow_matcher.sample_location_and_conditional_flow(
                x0, x1, t_start, t_start_next, branch_idx, training_geopath_net=True, t=t
            )
            ts.append(t)
            xts.append(xt)
            uts.append(ut)
            t_start = t_start_next

        return ts, xts, uts
|
|
    def training_step(self, batch, batch_idx):
        """Compute, normalise, and log the geopath velocity loss for one batch.

        Returns the loss divided by ``self.first_loss`` (when available) so the
        optimisation target starts near 1.
        """
        # Unwrap (batch,) tuples some dataloaders produce.
        if isinstance(batch, (list, tuple)):
            batch = batch[0]
        if isinstance(batch, dict) and "train_samples" in batch:
            main_batch = batch["train_samples"]
            metric_batch = batch["metric_samples"]
            if isinstance(main_batch, tuple):
                main_batch = main_batch[0]
            if isinstance(metric_batch, tuple):
                metric_batch = metric_batch[0]
        else:
            # Fallback: treat the whole batch as the training samples.
            main_batch = batch.get("train_samples", batch)
            metric_batch = batch.get("metric_samples", [])

        # NOTE(review): the prints below fire on the first batch of every epoch;
        # consider removing or demoting to logger.debug once the data pipeline
        # is stable.
        if batch_idx == 0:
            print(f"DEBUG batch type: {type(batch)}")
            if isinstance(batch, dict):
                print(f"DEBUG batch keys: {batch.keys()}")
                print(f"DEBUG train_samples type: {type(batch.get('train_samples'))}")
                if isinstance(batch.get("train_samples"), dict):
                    print(f"DEBUG train_samples keys: {batch['train_samples'].keys()}")
                    print(f"DEBUG x0 type: {type(batch['train_samples'].get('x0'))}")
                    if 'x0' in batch['train_samples']:
                        x0_item = batch['train_samples']['x0']
                        print(f"DEBUG x0 structure: {type(x0_item)}")
                        if isinstance(x0_item, (list, tuple)):
                            print(f"DEBUG x0 length: {len(x0_item)}")
                            if len(x0_item) > 0:
                                print(f"DEBUG x0[0] shape: {x0_item[0].shape if hasattr(x0_item[0], 'shape') else 'no shape'}")
            print(f"DEBUG main_batch type: {type(main_batch)}")
            if isinstance(main_batch, dict):
                print(f"DEBUG main_batch keys: {main_batch.keys()}")

        # Uniform time grid on [0, 1], one point per marginal in x0.
        self.timesteps = torch.linspace(0.0, 1.0, len(main_batch["x0"])).tolist()
        tangential_velocity_loss = self._compute_loss(main_batch, metric_batch)

        # Normalise by the reference loss measured in on_train_start.
        if self.first_loss:
            tangential_velocity_loss = tangential_velocity_loss / self.first_loss

        # Magnitude of the geopath net's last output — a diagnostic of how much
        # the learned correction deviates from the base interpolation.
        self.log(
            "BranchPathNet/mean_geopath_geopath",
            (self.flow_matcher.geopath_net_output.abs().mean()),
            on_step=False,
            on_epoch=True,
            prog_bar=True,
        )

        self.log(
            "BranchPathNet/train_loss_geopath",
            tangential_velocity_loss,
            on_step=True,
            on_epoch=True,
            prog_bar=True,
            logger=True,
        )

        return tangential_velocity_loss
|
|
| def validation_step(self, batch, batch_idx): |
| |
| if isinstance(batch, (list, tuple)): |
| batch = batch[0] |
| if isinstance(batch, dict) and "val_samples" in batch: |
| main_batch = batch["val_samples"] |
| metric_batch = batch["metric_samples"] |
| if isinstance(main_batch, tuple): |
| main_batch = main_batch[0] |
| if isinstance(metric_batch, tuple): |
| metric_batch = metric_batch[0] |
| else: |
| |
| main_batch = batch.get("val_samples", batch) |
| metric_batch = batch.get("metric_samples", []) |
|
|
| self.timesteps = torch.linspace(0.0, 1.0, len(main_batch["x0"])).tolist() |
| tangential_velocity_loss = self._compute_loss(main_batch, metric_batch) |
| if self.first_loss: |
| tangential_velocity_loss = tangential_velocity_loss / self.first_loss |
| |
| self.log( |
| "BranchPathNet/val_loss_geopath", |
| tangential_velocity_loss, |
| on_step=False, |
| on_epoch=True, |
| prog_bar=True, |
| logger=True, |
| ) |
| return tangential_velocity_loss |
| |
| |
    def test_step(self, batch, batch_idx):
        """Visualise each branch's learned geopath on the LiDAR scene.

        Interpolates x0 -> x1 at t in {1/4, 1/2, 3/4} with the branch's geopath
        net and saves one 3-D scatter figure per branch under
        ``<working_dir>/results/<run_name>/figures/``.
        """
        if isinstance(batch, dict):
            main_batch = batch["test_samples"]
            metric_batch = batch["metric_samples"]
            # Collapse one level of tuple nesting, if present.
            if isinstance(main_batch, tuple):
                main_batch = main_batch[0]
            if isinstance(metric_batch, tuple):
                metric_batch = metric_batch[0]
        else:
            # Nested list layout: [(main_batch,), (metric_batch,)].
            main_batch = batch[0][0]
            metric_batch = batch[1][0]

        x0 = main_batch["x0"][0]
        # Background point cloud of the scene.
        cloud_points = main_batch["dataset"][0]

        x0 = x0.to(self.device)
        cloud_points = cloud_points.to(self.device)

        # Interpolation times and their legend labels.
        t_vals = [0.25, 0.5, 0.75]
        t_labels = ["t=1/4", "t=1/2", "t=3/4"]

        # Fixed colour per point group.
        colors = {
            "x0": "#4D176C",
            "t=1/4": "#5C3B9D",
            "t=1/2": "#6172B9",
            "t=3/4": "#AC4E51",
            "x1": "#771F4F",
        }

        # Undo whitening so the scene is plotted in original data coordinates.
        if self.whiten:
            cloud_points = torch.tensor(
                self.trainer.datamodule.scaler.inverse_transform(cloud_points.cpu().numpy())
            )

        for i in range(self.branches):
            geopath = self.geopath_nets[i]
            x1_key = f"x1_{i + 1}"
            # Branches without a final distribution in the batch are skipped.
            if x1_key not in main_batch:
                print(f"Skipping branch {i + 1}: no final distribution {x1_key}")
                continue

            x1 = main_batch[x1_key][0].to(self.device)
            print(x1.shape)
            print(x0.shape)
            # Evaluate the geopath at each t; move to CPU (and un-whiten) for plotting.
            interpolated_points = []
            with torch.no_grad():
                for t_scalar in t_vals:
                    t_tensor = torch.full((x0.shape[0], 1), t_scalar, device=self.device)
                    xt = geopath(x0, x1, t_tensor).cpu()
                    if self.whiten:
                        xt = torch.tensor(
                            self.trainer.datamodule.scaler.inverse_transform(xt.numpy())
                        )
                    interpolated_points.append(xt)

            if self.whiten:
                x0_plot = torch.tensor(
                    self.trainer.datamodule.scaler.inverse_transform(x0.cpu().numpy())
                )
                x1_plot = torch.tensor(
                    self.trainer.datamodule.scaler.inverse_transform(x1.cpu().numpy())
                )
            else:
                x0_plot = x0.cpu()
                x1_plot = x1.cpu()

            # 3-D scene with the LiDAR cloud as background.
            fig = plt.figure(figsize=(6, 5))
            ax = fig.add_subplot(111, projection="3d", computed_zorder=False)
            ax.view_init(elev=30, azim=-115, roll=0)
            plot_lidar(ax, cloud_points)

            # Source distribution.
            ax.scatter(
                x0_plot[:, 0], x0_plot[:, 1], x0_plot[:, 2],
                s=15, alpha=1.0, color=colors["x0"], label="x₀", depthshade=True,
                edgecolors="white",
                linewidths=0.3
            )

            # Intermediate interpolants.
            for xt, t_label in zip(interpolated_points, t_labels):
                ax.scatter(
                    xt[:, 0], xt[:, 1], xt[:, 2],
                    s=15, alpha=1.0, color=colors[t_label], label=t_label, depthshade=True,
                    edgecolors="white",
                    linewidths=0.3
                )

            # Target distribution.
            ax.scatter(
                x1_plot[:, 0], x1_plot[:, 1], x1_plot[:, 2],
                s=15, alpha=1.0, color=colors["x1"], label="x₁", depthshade=True,
                edgecolors="white",
                linewidths=0.3
            )

            ax.legend()

            # Save under <working_dir>/results/<run_name>/figures/.
            run_name = self.args.run_name if hasattr(self.args, 'run_name') and self.args.run_name else self.args.data_name
            results_dir = os.path.join(self.args.working_dir, 'results', run_name)
            figures_dir = os.path.join(results_dir, 'figures')
            os.makedirs(figures_dir, exist_ok=True)

            save_path = f"{figures_dir}/lidar_geopath_branch_{i+1}.png"
            plt.savefig(save_path, dpi=300)
            plt.close()
|
|
| def optimizer_step(self, *args, **kwargs): |
| super().optimizer_step(*args, **kwargs) |
| for net in self.geopath_nets: |
| if isinstance(net, EMA): |
| net.update_ema() |
|
|
| def configure_optimizers(self): |
| if self.optimizer_name == "adam": |
| optimizer = torch.optim.Adam( |
| itertools.chain(*[net.parameters() for net in self.geopath_nets]), lr=self.lr |
| ) |
| elif self.optimizer_name == "adamw": |
| optimizer = torch.optim.AdamW( |
| itertools.chain(*[net.parameters() for net in self.geopath_nets]), lr=self.lr |
| ) |
| return optimizer |
|
|