# FastSplatStyler / graph_networks / graph_transforms.py
# Uploaded by incrl — Initial Upload (attempt 2), commit 5b557cf (verified)
""" graph_transforms.py
This is the implementation of transforming a traditional CNN
to a SelectionConv-based graph CNN
So far this is just used for segmentation
"""
from copy import deepcopy
from typing import Dict, Iterable, OrderedDict, Tuple, Union #,Literal Only supported in Python 3.8+
from typing_extensions import Literal
import warnings
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.data import Data
from torchvision.models.segmentation.fcn import FCN
from selectionConv import SelectionConv
import pooling as P
def transform_network(network: nn.Module):
    """ Recursively converts a tensor based network into its graph based equivalent.

    Leaf module types listed in ``__MAPPING`` are swapped for their Sel*
    counterparts (with weights copied); container modules are rebuilt by
    transforming each child in place.

    Parameters
    ----------
    - network: the network to transform

    Returns
    -------
    - the transformed network
    """
    # work on a copy so the caller's original network is left untouched
    network = deepcopy(network)
    replacement = __MAPPING.get(type(network))
    if replacement is not None:
        with torch.no_grad():
            return replacement.from_torch(network)
    if not isinstance(network, nn.Module):
        raise ValueError(f"Must be of type Module but got: {type(network)}")
    children = list(network.named_children())
    if not children:
        # a leaf module with no graph equivalent registered in __MAPPING
        raise NotImplementedError(f"{type(network)} is not implemented yet")
    for name, child in children:
        setattr(network, name, transform_network(child))
    return network
class GraphTracker:
    """ Pairs graph data with the node features and pooling level currently
    being operated on, so module forward functions can be swapped in easily.

    Parameters
    ----------
    - graph: the graph data
    - x: the node data at the current level (defaults to ``graph.x``)
    - level: the current depth the graph is being operated on
    """

    def __init__(self, graph, x=None, level=0):
        self.graph = graph
        self.level = level
        self.x = x if x is not None else graph.x

    def from_x(self, x):
        """ build a tracker over the same graph/level but new node values"""
        return GraphTracker(self.graph, x, level=self.level)

    def edge_index(self):
        # edges for the current resolution level
        return self.graph.edge_indexes[self.level]

    def selections(self):
        return self.graph.selections_list[self.level]

    def interps(self):
        # interpolation weights are optional on the graph
        if not hasattr(self.graph, "interps_list"):
            return None
        return self.graph.interps_list[self.level]

    def cluster(self):
        return self.graph.clusters[self.level]

    def __iadd__(self, other):
        # elementwise residual-style accumulation of node features
        self.x = self.x + other.x
        return self

    def __repr__(self):
        return f"GraphTracker(x={tuple(self.x.shape)},level={self.level})"
def _single(pair, name):
""" converts a tuple into a single number
Parameters
----------
- pair: the potential pair of values
- name: the name of the values for more readable errors
Returns
-------
- the single value
"""
if isinstance(pair, int):
return pair
if not isinstance(pair, tuple):
raise ValueError(f"{name} must either be int or tuple but got: {type(pair)}")
if len(pair) != 2:
raise ValueError(f"{name} must be a 2-tuple but got: {pair}")
if pair[0] != pair[1]:
raise ValueError(f"{name} must be a square tuple")
return pair[0]
class SelModule(nn.Module):
    """ Common base class that every graph based module inherits from.
    """

    @classmethod
    def from_torch(cls, network):
        """ builds the graph based equivalent of an existing 2d based module,
        copying its weights accordingly; each subclass must override this
        Parameters
        ----------
        - network: the existing 2d based module
        Returns
        -------
        - the new graph based module
        """
        raise NotImplementedError
class SelConv(SelModule, nn.modules.conv._ConvNd):
    """ Drop-in graph replacement for a 2d convolution, backed by SelectionConv.

    Parameters
    ----------
    - in_channels: the number of incoming channels
    - out_channels: the number of outgoing channels
    - kernel_size: the size of the convolution kernel
    - stride: the stride at which to perform convolution (only 1 or 2)
    - padding: the amount of padding to be used
    - dilation: the dilation of the kernel
    - groups: the groups of filters for the convolution
    - bias: whether or not to include a bias
    - padding_mode: the type of padding to be used
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: Union[int, Tuple[int, int]],
        stride: Union[int, Tuple[int, int]]=1,
        padding: Union[int, Tuple[int, int]]=0,
        dilation: Union[int, Tuple[int, int]]=1,
        groups: int = 1,
        bias: bool=True,
        padding_mode: str='zeros',
        device=None,
        dtype=None,
    ):
        # _ConvNd expects (transposed=False, output_padding=(0,)) before groups
        super().__init__(
            in_channels, out_channels, kernel_size, stride, padding, dilation,
            False, (0,), groups, bias, padding_mode, device=device, dtype=dtype,
        )
        self.single_stride = _single(stride, "stride")
        if self.single_stride not in (1, 2):
            raise NotImplementedError(f"Only strides of 1 and 2 are supported but got {stride}")
        # the actual graph convolution operator (holds the learned weights)
        self.conv_operation = SelectionConv(
            in_channels,
            out_channels,
            _single(kernel_size, "kernel_size"),
            _single(dilation, "dilation"),
            padding_mode,
        )

    def forward(self, inputs: GraphTracker):
        convolved = self.conv_operation(inputs.x, inputs.edge_index(), inputs.selections(), inputs.interps())
        result = inputs.from_x(convolved)
        if self.single_stride == 2:
            # a strided conv is emulated as conv + cluster based stride pooling
            result.x = P.stridePoolCluster(convolved, result.cluster())
            result.level += 1
        return result

    @classmethod
    def from_torch(cls, network):
        graph_conv = SelConv(
            network.in_channels,
            network.out_channels,
            network.kernel_size,
            network.stride,
            network.padding,
            network.dilation,
            network.groups,
            network.bias is not None,
            network.padding_mode,
        )
        graph_conv.conv_operation.copy_weights(network.weight, network.bias)
        return graph_conv
class SelMaxPool(SelModule):
    """ Graph based max pooling.
    Parameters
    ----------
    - kernel_size: the size of the maxpool kernel
    """

    def __init__(self, kernel_size):
        super().__init__()
        self.kernel_size = kernel_size

    def forward(self, inputs):
        pooled = P.maxPoolKernel(inputs.x, inputs.edge_index(), inputs.selections(), inputs.cluster(), self.kernel_size)
        result = inputs.from_x(pooled)
        # pooling moves the tracker one resolution level deeper
        result.level += 1
        return result

    @classmethod
    def from_torch(cls, network):
        return SelMaxPool(network.kernel_size)
class SelBatchNorm(SelModule):
    """ Graph based BatchNorm; node features are normalized with BatchNorm1d.
    """

    def __init__(self, num_features):
        super().__init__()
        self.bn = nn.BatchNorm1d(num_features)

    def forward(self, inputs):
        return inputs.from_x(self.bn(inputs.x))

    def copyBatchNorm(self, source):
        """ adopt the parameters and running statistics of an existing BatchNorm"""
        self.bn.weight = source.weight
        self.bn.bias = source.bias
        self.bn.running_mean = source.running_mean
        self.bn.running_var = source.running_var
        self.bn.eps = source.eps

    @classmethod
    def from_torch(cls, network):
        graph_bn = SelBatchNorm(network.num_features)
        graph_bn.copyBatchNorm(network)
        return graph_bn
class SelReLU(SelModule):
    """ Graph based ReLU.
    Parameters
    ----------
    - inplace: whether or not to perform relu in place
    """

    def __init__(self, inplace=False):
        super().__init__()
        self.inplace = inplace

    def forward(self, inputs):
        activated = F.relu(inputs.x, self.inplace)
        if self.inplace:
            # the tensor was modified in place, so reuse the incoming tracker
            inputs.x = activated
            return inputs
        return inputs.from_x(activated)

    @classmethod
    def from_torch(cls, network):
        return SelReLU(network.inplace)
class SelSequential(SelModule, nn.Sequential):
    """ Graph based Sequential: every child module is transformed in order.
    """

    @classmethod
    def from_torch(cls, network: nn.Sequential):
        transformed_children = [transform_network(child) for child in network]
        return SelSequential(*transformed_children)
class SelDropout(SelModule):
    """ Graph based dropout; acts as the identity (no nodes are dropped).
    """

    def forward(self, inputs):
        # pass-through: matches nn.Dropout behavior in eval mode
        return inputs

    @classmethod
    def from_torch(cls, network):
        return SelDropout()
def sel_binlinear_interp(
    inputs: GraphTracker,
    up_or_down: Literal["up", "down"]="up",
) -> GraphTracker:
    """ Performs bilinear interpolation as a single cluster step
    Parameters
    ----------
    - inputs: the input graph
    - up_or_down: either "up" or "down" indicating if it is upsampling or downsampling
    Returns
    -------
    - the interpolated graph
    """
    if up_or_down not in ("up", "down"):
        raise ValueError(f"up_or_down must either be 'up' or 'down' not: {up_or_down}")
    ret = inputs.from_x(inputs.x)
    # "up" decreases the level index (finer resolution), "down" increases it
    ret.level += -1 if up_or_down == "up" else 1
    # cluster/edge/interp data must come from the *target* level
    ret.x = P.unpoolInterpolated(ret.x, ret.cluster(), ret.edge_index(), ret.interps())
    return ret
def sel_interpolate(
    inputs: GraphTracker,
    target_level: int,
) -> GraphTracker:
    """ Interpolates a graph, one level at a time, until it reaches target_level.
    Parameters
    ----------
    - inputs: the input graph data
    - target_level: the level to interpolate to
    Returns
    -------
    - the interpolated graph
    """
    direction = "up" if target_level < inputs.level else "down"
    current = inputs
    while current.level != target_level:
        current = sel_binlinear_interp(current, direction)
    return current
class SelSimpleSegmentationModel(SelModule):
    """ Graph version of the simple segmentation model from torchvision.
    Needed because our graph interpolate function takes different parameters
    than torch's tensor based interpolate.
    """
    __constants__ = ["aux_classifier"]

    def __init__(self, backbone, classifier, aux_classifier = None):
        super().__init__()
        self.backbone = backbone
        self.classifier = classifier
        self.aux_classifier = aux_classifier

    def forward(self, x: GraphTracker) -> Dict[str, GraphTracker]:
        # remember the input resolution so outputs can be upsampled back to it
        input_level = x.level
        features = self.backbone(x)
        result = OrderedDict()
        out = self.classifier(features["out"])
        result["out"] = sel_interpolate(out, input_level)
        if self.aux_classifier is not None:
            aux = self.aux_classifier(features["aux"])
            result["aux"] = sel_interpolate(aux, input_level)
        return result

    @classmethod
    def from_torch(cls, network):
        aux = network.aux_classifier
        return SelSimpleSegmentationModel(
            backbone=transform_network(network.backbone),
            classifier=transform_network(network.classifier),
            aux_classifier=transform_network(aux) if aux is not None else None,
        )
# Lookup table from 2d torch module types to their graph based replacements.
# transform_network consults this (by exact type) to decide when to stop
# recursing and swap in a Sel* module via from_torch.
__MAPPING = {
    nn.Conv2d: SelConv,
    nn.BatchNorm2d: SelBatchNorm,
    nn.ReLU: SelReLU,
    nn.Sequential: SelSequential,
    nn.Dropout: SelDropout,
    nn.MaxPool2d: SelMaxPool,
    FCN: SelSimpleSegmentationModel,
}