| import torch
|
| import torch.nn as nn
|
| import torch.nn.functional as F
|
|
|
| from functools import partial
|
| import numpy as np
|
| import torch
|
| from tqdm import tqdm
|
| import math, random
|
|
|
|
|
|
|
def tensor_kmeans_sklearn(data_vecs, n_clusters=7, metric='euclidean', need_layer_masks=False, max_iters=20):
    """Cluster the per-pixel feature vectors of a single image tensor with sklearn KMeans.

    :param data_vecs: (torch.tensor) feature map of shape (1, C, H, W); only
        batch size 1 is supported.
    :param n_clusters: (int) number of clusters.
    :param metric: unused here (sklearn KMeans is always euclidean); kept for
        interface parity with tensor_kmeans_pytorch.
    :param need_layer_masks: (bool) if True return one-hot masks of shape
        (1, n_clusters, H, W); otherwise integer id maps of shape (1, 1, H, W).
    :param max_iters: unused here; kept for interface parity.
    :return: (torch.tensor) cluster masks or id maps (see need_layer_masks).
    """
    # Bug fix: KMeans was referenced but never imported anywhere in this file,
    # so this function raised NameError. Imported locally so the pytorch-only
    # code path keeps working without sklearn installed.
    from sklearn.cluster import KMeans

    N, C, H, W = data_vecs.shape
    assert N == 1, 'only support singe image tensor'

    # (1, C, H, W) -> (H*W, C): one C-dim sample per spatial location.
    data_vecs = data_vecs.permute(0, 2, 3, 1).view(-1, C)

    data_vecs_np = data_vecs.squeeze().detach().to("cpu").numpy()
    km = KMeans(n_clusters=n_clusters, init='k-means++', n_init=10, max_iter=300)
    km.fit_predict(data_vecs_np)
    cluster_ids_x = torch.from_numpy(km.labels_).to(data_vecs.device)
    id_maps = cluster_ids_x.reshape(1, 1, H, W).long()
    if need_layer_masks:
        # One-hot over the cluster ids, then move the class axis to dim 1.
        one_hot_labels = F.one_hot(id_maps.squeeze(1), num_classes=n_clusters).float()
        cluster_mask = one_hot_labels.permute(0, 3, 1, 2)
        return cluster_mask
    return id_maps
|
|
|
|
|
def tensor_kmeans_pytorch(data_vecs, n_clusters=7, metric='euclidean', need_layer_masks=False, max_iters=20):
    """Cluster the per-pixel feature vectors of a single image tensor with the
    in-file pytorch kmeans implementation.

    :param data_vecs: (torch.tensor) feature map of shape (1, C, H, W); only
        batch size 1 is supported.
    :param n_clusters: (int) number of clusters.
    :param metric: (str) distance passed through to kmeans ('euclidean'/'cosine').
    :param need_layer_masks: (bool) if True return one-hot masks of shape
        (1, n_clusters, H, W); otherwise id maps of shape (1, 1, H, W).
    :param max_iters: (int) iteration limit passed through to kmeans.
    :return: (torch.tensor) cluster masks or id maps (see need_layer_masks).
    """
    batch, channels, height, width = data_vecs.shape
    assert batch == 1, 'only support singe image tensor'

    # Flatten the spatial grid: every pixel becomes one channels-dim sample.
    pixel_feats = data_vecs.permute(0, 2, 3, 1).view(-1, channels)

    labels, _ = kmeans(X=pixel_feats, num_clusters=n_clusters, distance=metric,
                       tqdm_flag=False, iter_limit=max_iters, device=pixel_feats.device)
    id_maps = labels.reshape(1, 1, height, width)
    if not need_layer_masks:
        return id_maps
    # One-hot over cluster ids, class axis moved to dim 1.
    layer_masks = F.one_hot(id_maps.squeeze(1), num_classes=n_clusters).float()
    return layer_masks.permute(0, 3, 1, 2)
|
|
|
|
|
def batch_kmeans_pytorch(data_vecs, n_clusters=7, metric='euclidean', use_sklearn_kmeans=False):
    """Run single-image kmeans over every sample of a batch.

    :param data_vecs: (torch.tensor) feature maps of shape (N, C, H, W).
    :param n_clusters: (int) number of clusters per image.
    :param metric: (str) distance metric forwarded to the per-image kmeans.
    :param use_sklearn_kmeans: (bool) pick the sklearn backend instead of the
        pytorch one.
    :return: (torch.tensor) stacked one-hot cluster masks, (N, n_clusters, H, W).
    """
    single_image_kmeans = tensor_kmeans_sklearn if use_sklearn_kmeans else tensor_kmeans_pytorch
    masks = [
        single_image_kmeans(data_vecs[idx:idx + 1, :, :, :], n_clusters, metric, True)
        for idx in range(data_vecs.shape[0])
    ]
    return torch.cat(masks, dim=0)
|
|
|
|
|
def get_centroid_candidates(data_vecs, n_clusters=7, metric='euclidean', max_iters=20):
    """Return the kmeans centroids of the pixel features of one feature map.

    :param data_vecs: (torch.tensor) feature map of shape (N, C, H, W).
    :param n_clusters: (int) number of centroids.
    :param metric: (str) distance passed to kmeans ('euclidean'/'cosine').
    :param max_iters: (int) iteration limit for kmeans.
    :return: (torch.tensor) centroids of shape (n_clusters, C).
    """
    channels = data_vecs.shape[1]
    # Treat each spatial location as one channels-dim sample.
    samples = data_vecs.permute(0, 2, 3, 1).view(-1, channels)
    _, centers = kmeans(X=samples, num_clusters=n_clusters, distance=metric,
                        tqdm_flag=False, iter_limit=max_iters, device=samples.device)
    return centers
|
|
|
|
|
def find_distinctive_elements(data_tensor, n_clusters=7, topk=3, metric='euclidean'):
    """For each image in the batch, cluster its pixel features and mark, per
    centroid, the topk pixels closest to it in squared euclidean distance.

    :param data_tensor: (torch.tensor) feature maps of shape (N, C, H, W).
    :param n_clusters: (int) number of kmeans centroids per image.
    :param topk: (int) how many nearest pixels per centroid to flag.
    :param metric: (str) distance passed to kmeans ('euclidean'/'cosine').
    :return: (torch.tensor) float mask of shape (N, n_clusters, H, W); 1 where
        a pixel is among the topk nearest to a centroid (distance ties at the
        threshold may select more than topk, as in the original).
    """
    N, C, H, W = data_tensor.shape
    centroid_list = []
    for idx in range(N):
        cluster_centers = get_centroid_candidates(data_tensor[idx:idx + 1, :, :, :], n_clusters, metric)
        centroid_list.append(cluster_centers)

    batch_centroids = torch.stack(centroid_list, dim=0)  # (N, K, C)
    data_vecs = data_tensor.flatten(2)                   # (N, C, H*W)

    # Squared euclidean distance via ||a - b||^2 = ||a||^2 - 2*a.b + ||b||^2.
    # Perf fix: the original built full AtA (K x K) and BtB (H*W x H*W) gram
    # matrices only to read their diagonals — O((H*W)^2) memory/time; the
    # squared norms are computed directly and broadcast instead of .repeat().
    AtB = torch.matmul(batch_centroids, data_vecs)        # (N, K, H*W)
    A2 = batch_centroids.pow(2).sum(dim=2, keepdim=True)  # (N, K, 1)
    B2 = data_vecs.pow(2).sum(dim=1, keepdim=True)        # (N, 1, H*W)
    distance_map = A2 - 2 * AtB + B2

    values = distance_map.topk(topk, dim=2, largest=False, sorted=True).values
    # Keep every pixel whose distance does not exceed the topk-th smallest.
    cluster_mask = torch.where(distance_map <= values[:, :, topk - 1:],
                               torch.ones_like(distance_map),
                               torch.zeros_like(distance_map))
    return cluster_mask.view(N, n_clusters, H, W)
|
|
|
|
|
|
|
'''
The kmeans implementation below is adapted from the kmeans_pytorch project:
https://github.com/subhadarship/kmeans_pytorch
'''
|
|
|
|
|
def initialize(X, num_clusters):
    """
    initialize cluster centers

    Deterministically picks num_clusters distinct rows of X at random.

    Bug fix: the original called np.random.seed(1), which reseeded numpy's
    *global* RNG on every invocation as a side effect. A locally seeded
    RandomState(1) produces the identical index sequence without touching
    global state.

    :param X: (torch.tensor) matrix of shape (num_samples, dim)
    :param num_clusters: (int) number of clusters
    :return: (torch.tensor) initial state — the selected rows of X
    """
    rng = np.random.RandomState(1)
    num_samples = len(X)
    indices = rng.choice(num_samples, num_clusters, replace=False)
    initial_state = X[indices]
    return initial_state
|
|
|
|
|
def kmeans(
    X,
    num_clusters,
    distance='euclidean',
    cluster_centers=[],  # NOTE(review): mutable default; only type-checked below, never mutated
    tol=1e-4,
    tqdm_flag=True,
    iter_limit=0,
    device=torch.device('cpu'),
    gamma_for_soft_dtw=0.001
):
    """
    perform kmeans (Lloyd's algorithm) on the rows of X
    :param X: (torch.tensor) matrix of shape (num_samples, dim)
    :param num_clusters: (int) number of clusters
    :param distance: (str) distance [options: 'euclidean', 'cosine'] [default: 'euclidean']
    :param cluster_centers: pass a tensor of centers to resume from; the
        default empty list means "initialize from random rows of X"
    :param tol: (float) threshold on the squared total center shift [default: 0.0001]
    :param device: (torch.device) device [default: cpu]
    :param tqdm_flag: Allows to turn logs on and off
    :param iter_limit: hard limit for max number of iterations (0 = no limit)
    :param gamma_for_soft_dtw: approaches to (hard) DTW as gamma -> 0
        (unused in this copy: no 'soft_dtw' branch below)
    :return: (torch.tensor, torch.tensor) cluster ids, cluster centers
    """
    if tqdm_flag:
        print(f'running k-means on {device}..')

    # Pick the pairwise distance implementation once up front.
    if distance == 'euclidean':
        pairwise_distance_function = partial(pairwise_distance, device=device, tqdm_flag=tqdm_flag)
    elif distance == 'cosine':
        pairwise_distance_function = partial(pairwise_cosine, device=device)
    else:
        raise NotImplementedError

    X = X.float()
    X = X.to(device)

    # Initialization: fresh random rows of X, or — when resuming from given
    # centers — snap each provided center to its nearest data point.
    if type(cluster_centers) == list:
        initial_state = initialize(X, num_clusters)
    else:
        if tqdm_flag:
            print('resuming')
        initial_state = cluster_centers
        dis = pairwise_distance_function(X, initial_state)
        choice_points = torch.argmin(dis, dim=0)
        initial_state = X[choice_points]
        initial_state = initial_state.to(device)

    iteration = 0
    if tqdm_flag:
        tqdm_meter = tqdm(desc='[running kmeans]')
    while True:
        # Assignment step: nearest center for every sample.
        dis = pairwise_distance_function(X, initial_state)
        choice_cluster = torch.argmin(dis, dim=1)

        # Keep the previous centers to measure convergence.
        initial_state_pre = initial_state.clone()

        # Update step: each center becomes the mean of its assigned samples.
        for index in range(num_clusters):
            # NOTE(review): .squeeze() yields a 0-d index tensor when a
            # cluster has exactly one member — confirm index_select accepts
            # that on the torch versions in use.
            selected = torch.nonzero(choice_cluster == index).squeeze().to(device)
            selected = torch.index_select(X, 0, selected)

            # Re-seed an empty cluster with a random data point so its
            # center does not become NaN.
            if selected.shape[0] == 0:
                selected = X[torch.randint(len(X), (1,))]

            initial_state[index] = selected.mean(dim=0)

        # Total euclidean movement of all centers this iteration.
        center_shift = torch.sum(
            torch.sqrt(
                torch.sum((initial_state - initial_state_pre) ** 2, dim=1)
            ))

        iteration = iteration + 1

        if tqdm_flag:
            tqdm_meter.set_postfix(
                iteration=f'{iteration}',
                center_shift=f'{center_shift ** 2:0.6f}',
                tol=f'{tol:0.6f}'
            )
            tqdm_meter.update()
        # Converged: squared total shift below tolerance.
        if center_shift ** 2 < tol:
            break
        # Hard iteration cap (0 disables it).
        if iter_limit != 0 and iteration >= iter_limit:
            break

    return choice_cluster.to(device), initial_state.to(device)
|
|
|
|
|
def kmeans_predict(
    X,
    cluster_centers,
    distance='euclidean',
    device=torch.device('cpu'),
    gamma_for_soft_dtw=0.001,
    tqdm_flag=True
):
    """
    predict using cluster centers: assign each row of X to its nearest center
    :param X: (torch.tensor) matrix of shape (num_samples, dim)
    :param cluster_centers: (torch.tensor) cluster centers
    :param distance: (str) distance [options: 'euclidean', 'cosine'] [default: 'euclidean']
    :param device: (torch.device) device [default: 'cpu']
    :param gamma_for_soft_dtw: approaches to (hard) DTW as gamma -> 0
    :return: (torch.tensor) cluster ids, on the cpu
    """
    if tqdm_flag:
        print(f'predicting on {device}..')

    # Select the distance implementation matching the requested metric.
    if distance == 'euclidean':
        pairwise_distance_function = partial(pairwise_distance, device=device, tqdm_flag=tqdm_flag)
    elif distance == 'cosine':
        pairwise_distance_function = partial(pairwise_cosine, device=device)
    elif distance == 'soft_dtw':
        # NOTE(review): SoftDTW and pairwise_soft_dtw are not defined in this
        # file — this branch raises NameError unless they are provided
        # elsewhere; confirm before using distance='soft_dtw'.
        sdtw = SoftDTW(use_cuda=device.type == 'cuda', gamma=gamma_for_soft_dtw)
        pairwise_distance_function = partial(pairwise_soft_dtw, sdtw=sdtw, device=device)
    else:
        raise NotImplementedError

    X = X.float().to(device)

    dissimilarities = pairwise_distance_function(X, cluster_centers)
    return torch.argmin(dissimilarities, dim=1).cpu()
|
|
|
|
|
def pairwise_distance(data1, data2, device=torch.device('cpu'), tqdm_flag=True):
    """Squared euclidean distance between every row of data1 and every row of data2.

    :param data1: (torch.tensor) matrix of shape (N, dim)
    :param data2: (torch.tensor) matrix of shape (M, dim)
    :param device: (torch.device) device to compute on
    :param tqdm_flag: print the device when True
    :return: (torch.tensor) distance matrix of shape (N, M)
    """
    if tqdm_flag:
        print(f'device is :{device}')

    data1, data2 = data1.to(device), data2.to(device)

    # Broadcast to (N, 1, dim) against (1, M, dim).
    A = data1.unsqueeze(dim=1)
    B = data2.unsqueeze(dim=0)

    dis = (A - B) ** 2.0
    # Bug fix: the original applied .squeeze() here, which collapsed the
    # (N, M) matrix to 1-D (or 0-D) whenever N == 1 or M == 1 and broke the
    # argmin(dim=1) in callers; always keep the 2-D shape.
    dis = dis.sum(dim=-1)
    return dis
|
|
|
|
|
def pairwise_cosine(data1, data2, device=torch.device('cpu')):
    """Cosine distance (1 - cosine similarity) between rows of data1 and data2.

    :param data1: (torch.tensor) matrix of shape (N, dim)
    :param data2: (torch.tensor) matrix of shape (M, dim)
    :param device: (torch.device) device to compute on
    :return: (torch.tensor) distance matrix of shape (N, M)
    """
    data1, data2 = data1.to(device), data2.to(device)

    # Broadcast to (N, 1, dim) against (1, M, dim).
    A = data1.unsqueeze(dim=1)
    B = data2.unsqueeze(dim=0)

    # Normalize along the feature dim. NOTE(review): an all-zero row yields
    # NaN (0/0), exactly as in the original — confirm inputs are non-zero.
    A_normalized = A / A.norm(dim=-1, keepdim=True)
    B_normalized = B / B.norm(dim=-1, keepdim=True)

    cosine = A_normalized * B_normalized

    # Bug fix: the original applied .squeeze() here, which collapsed the
    # (N, M) result to 1-D (or 0-D) whenever N == 1 or M == 1 and broke the
    # argmin(dim=1) in callers; always keep the 2-D shape.
    cosine_dis = 1 - cosine.sum(dim=-1)
    return cosine_dis