| | import numpy as np |
| | import pickle |
| | from torch.utils.data import Dataset, DataLoader |
| | import os |
| | import torch |
| | from copy import deepcopy |
| | from blimpy import Waterfall |
| | from tqdm import tqdm |
| | from copy import deepcopy |
| | from sigpyproc.readers import FilReader |
| | from torch import nn |
| |
|
| |
|
def load_pickled_data(file_path):
    """Deserialize and return the object stored in the pickle at *file_path*.

    NOTE: pickle is unsafe on untrusted input; only load files you produced.
    """
    with open(file_path, 'rb') as handle:
        return pickle.load(handle)
| |
|
| | |
class CustomDataset(Dataset):
    """Dataset of pickled samples laid out as ``data_dir/<class>/<file>``.

    Each pickle holds a dict with key ``'data'`` (the sample array) and,
    for 8-bit captures, ``'8_data'``.  A sample's label is the index of
    its class subdirectory.

    Args:
        data_dir: root directory; each immediate subdirectory is one class.
        bit8: if True, read the ``'8_data'`` entry (cast to float32)
            instead of ``'data'``.
        transform: optional callable applied to the sample tensor.
    """

    def __init__(self, data_dir, bit8=False, transform=None):
        self.data_dir = data_dir
        self.transform = transform
        self.bit8 = bit8
        self.images = []
        self.labels = []
        # Sorted so the class -> index mapping is deterministic across runs
        # (os.listdir order is arbitrary and filesystem-dependent).
        self.classes = sorted(os.listdir(data_dir))
        self.class_to_idx = {cls: idx for idx, cls in enumerate(self.classes)}

        for cls in self.classes:
            class_dir = os.path.join(data_dir, cls)
            for image_name in os.listdir(class_dir):
                self.images.append(os.path.join(class_dir, image_name))
                self.labels.append(self.class_to_idx[cls])

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        """Return ``(sample, label)`` for the pickle at index ``idx``."""
        label = self.labels[idx]
        image = load_pickled_data(self.images[idx])

        if self.bit8:
            sample = torch.from_numpy(image['8_data']).type(torch.float32)
        else:
            sample = torch.from_numpy(image['data'])

        # Bug fix: the original assigned the output only inside
        # ``if self.transform is not None`` yet returned it unconditionally,
        # raising UnboundLocalError when no transform was given.
        if self.transform is not None:
            sample = self.transform(sample)

        return sample, label
| |
|
| | |
class CustomDataset_Masked(Dataset):
    """Dataset yielding ``(image, label, burst_mask)`` triples.

    Layout is ``data_dir/<class>/<pickle>``.  Each pickle is a dict with
    ``'data'`` (the spectrum; its ``.data`` attribute is read, presumably
    because it is a masked array -- TODO confirm) and ``'burst'`` (a burst
    intensity map, binarized here at 10% of its peak).
    """

    def __init__(self, data_dir, transform=None):
        self.data_dir = data_dir
        self.transform = transform
        self.images = []
        self.labels = []
        # Sorted for a deterministic class -> index mapping.
        self.classes = sorted(os.listdir(data_dir))
        self.class_to_idx = {cls: idx for idx, cls in enumerate(self.classes)}

        for cls in self.classes:
            class_dir = os.path.join(data_dir, cls)
            for image_name in os.listdir(class_dir):
                self.images.append(os.path.join(class_dir, image_name))
                self.labels.append(self.class_to_idx[cls])

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        """Return the transformed image, its label, and a binary burst mask
        replicated into the first three channels of the image shape."""
        label = self.labels[idx]
        image = load_pickled_data(self.images[idx])

        if self.transform is None:
            # Bug fix: the original fell off the end and returned None here,
            # which crashed the DataLoader later with an opaque error.
            raise ValueError("CustomDataset_Masked requires a transform")

        burst = image['burst']
        peak = burst.max()
        # Normalize by the peak unless the map is non-positive (avoids a
        # divide by zero).  Copy in the peak == 0 branch so the threshold
        # writes below never mutate the array inside the pickled dict
        # (torch.from_numpy shares memory with its numpy input).
        if peak == 0:
            mask = torch.from_numpy(burst.copy())
        else:
            mask = torch.from_numpy(burst / peak)

        # Binarize at 10% of peak (values exactly at the threshold -> 0,
        # as in the original <= comparison).
        above = mask > 0.1
        mask[above] = 1
        mask[~above] = 0

        new_image = self.transform(torch.from_numpy(image['data'].data))

        # Replicate the mask into the first three channels; any extra
        # channels the transform emits stay zero, as before.
        burst_arr = torch.zeros_like(new_image)
        for channel in range(3):
            burst_arr[channel] = mask

        return new_image, label, burst_arr
| |
|
| | |
class TestingDataset(Dataset):
    """Evaluation dataset returning the sample plus its injection parameters.

    Same directory layout as ``CustomDataset``.  Each pickle additionally
    stores a ``'params'`` dict whose ``'dm'``, ``'freq_ref'``, ``'snr'`` and
    ``'boxcard'`` entries are returned alongside the label.
    """

    def __init__(self, data_dir, bit8=False, transform=None):
        self.data_dir = data_dir
        self.transform = transform
        self.bit8 = bit8
        self.images = []
        self.labels = []
        # Sorted for a deterministic class -> index mapping.
        self.classes = sorted(os.listdir(data_dir))
        self.class_to_idx = {cls: idx for idx, cls in enumerate(self.classes)}

        for cls in self.classes:
            class_dir = os.path.join(data_dir, cls)
            for image_name in os.listdir(class_dir):
                self.images.append(os.path.join(class_dir, image_name))
                self.labels.append(self.class_to_idx[cls])

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        """Return ``(sample, (label, dm, freq_ref, snr, boxcard))``."""
        label = self.labels[idx]
        image = load_pickled_data(self.images[idx])
        params = image['params']

        if self.bit8:
            sample = torch.from_numpy(image['8_data']).type(torch.float32)
        else:
            sample = torch.from_numpy(image['data'])

        # Bug fixes: the original raised UnboundLocalError when transform
        # was None (output assigned only inside the if), read
        # ``image['params']`` twice, and wrote ``params['labels']`` into a
        # dict that was immediately discarded (dead code, removed).
        if self.transform is not None:
            sample = self.transform(sample)

        return sample, (label, params['dm'], params['freq_ref'], params['snr'], params['boxcard'])
| |
|
| | |
class SearchDataset(Dataset):
    """Slice a filterbank observation into fixed-length search windows.

    The observation comes either from a pickle (``pickle_data=True``; a dict
    with ``'header'`` and ``'data'``) or from a file read via blimpy's
    ``Waterfall``.  The (time, 1, chan) dynamic spectrum is cut into
    non-overlapping windows of ``window_size`` time samples.
    """

    def __init__(self, data_dir, transform=None, pickle_data=False):
        self.window_size = 2048

        if pickle_data:
            with open(data_dir, 'rb') as f:
                self.d = pickle.load(f)
            self.header = self.d['header']
            self.images = self.crop(self.d['data'][:, 0, :], self.window_size)
        else:
            # max_load caps blimpy's memory use (presumably in GB -- confirm).
            self.obs = Waterfall(data_dir, max_load=50)
            self.header = self.obs.header
            self.images = self.crop(self.obs.data[:, 0, :], self.window_size)
        self.transform = transform
        self.SEC_PER_DAY = 86400

    def crop(self, data, window_size=2048):
        """Cut (time, chan) ``data`` into full non-overlapping windows.

        Returns a float64 array of shape (n_windows, window_size, n_chan);
        any trailing partial window is discarded.  The channel count was
        previously hard-coded to 192; it now follows the input shape.
        """
        n_windows = data.shape[0] // window_size
        windows = np.zeros((n_windows, window_size, data.shape[1]))
        for i in range(n_windows):
            windows[i] = data[i * window_size:(i + 1) * window_size]
        return windows

    def __len__(self):
        return self.images.shape[0]

    def __getitem__(self, idx):
        """Return ``(window, idx)``, the window transposed to (chan, time)."""
        data = self.images[idx, :, :].T
        # Bug fixes: the original computed an MJD timestamp that was never
        # used or returned (dead code, removed) and raised
        # UnboundLocalError when transform was None.
        if self.transform is not None:
            return self.transform(data), idx
        return data, idx
| |
|
| | |
class SearchDataset_Sigproc(Dataset):
    """Slice a SIGPROC filterbank file into fixed-length search windows.

    Reads the whole file with sigpyproc, trims 1024 elements from each end
    of axis 1 of the block (presumably the time axis of ``read_block``'s
    (chan, time) layout -- confirm against sigpyproc), swaps to
    (time, chan), and cuts the result into non-overlapping windows of
    ``window_size`` time samples.
    """

    def __init__(self, data_dir, transform=None):
        self.window_size = 2048
        fil = FilReader(data_dir)
        self.header = fil.header

        read_data = fil.read_block(0, fil.header.nsamples)[:, 1024:-1024]
        read_data = np.swapaxes(read_data, 0, -1)
        self.images = self.crop(read_data, self.window_size)
        self.transform = transform
        self.SEC_PER_DAY = 86400

    def crop(self, data, window_size=2048):
        """Cut (time, chan) ``data`` into full non-overlapping windows.

        Returns a float64 array of shape (n_windows, window_size, n_chan);
        any trailing partial window is discarded.  The channel count was
        previously hard-coded to 192; it now follows the input shape.
        """
        n_windows = data.shape[0] // window_size
        windows = np.zeros((n_windows, window_size, data.shape[1]))
        for i in range(n_windows):
            windows[i] = data[i * window_size:(i + 1) * window_size]
        return windows

    def __len__(self):
        return self.images.shape[0]

    def __getitem__(self, idx):
        """Return ``(window, idx)``, the window transposed to (chan, time)."""
        data = self.images[idx, :, :].T
        # Bug fixes: removed the dead MJD timestamp computation and the
        # UnboundLocalError raised when transform was None.
        tensor = torch.from_numpy(data)
        if self.transform is not None:
            tensor = self.transform(tensor)
        return tensor, idx
| |
|
| | |
| | |
| | |
| | |
| |
|
def renorm(data):
    """Standardize *data* to zero mean and unit (sample) standard deviation."""
    return (data - torch.mean(data)) / torch.std(data)
| |
|
def transform(data):
    """Build a multi-channel, frequency-chunked representation of *data*.

    Channel 0 is the log-scaled, standardized input.  Each further channel
    is a log-scaled copy with values zeroed on one side of a sigma
    threshold: scale -1 zeroes everything below mean + 1*std, scale 5
    zeroes everything above mean + 5*std.  The (3, H, W) stack is then
    split into 8 chunks along the last axis and flattened into a
    (24, H, W/8) float32 tensor.
    """
    sigma_scales = [-1, 5]
    mu = torch.mean(data)
    sigma = torch.std(data)

    def _log_norm(x):
        # log-scale (epsilon avoids log10(0)), then standardize.
        return renorm(torch.log10(x + 1e-10))

    channels = [_log_norm(data.detach().clone())]
    for scale in sigma_scales:
        clipped = data.detach().clone()
        threshold = abs(scale) * sigma + mu
        if scale < 0:
            clipped[clipped < threshold] = 0
        else:
            clipped[clipped > threshold] = 0
        channels.append(_log_norm(clipped))

    stacked = torch.stack(channels).type(torch.float32)
    chunks = torch.chunk(stacked, 8, dim=-1)
    stacked = torch.stack(chunks, dim=1)
    return stacked.view(-1, stacked.size(2), stacked.size(3))
| |
|
| |
|
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
def renorm_batched(data):
    """Min-max normalize each sample of a batch to [0, 1].

    Args:
        data: (batch, H, W) tensor; each sample is shifted by its own
            minimum and divided by its own (shifted) maximum.

    Returns:
        Tensor of the same shape with every sample scaled to [0, 1].
        NOTE(review): a constant sample divides by zero and yields NaN,
        matching the original behavior.

    The original materialized the per-sample min/max with ``expand`` to a
    hard-coded (batch, 192, 2048) shape; ``keepdim`` broadcasting gives the
    same result for any spatial size.
    """
    mins = torch.amin(data, dim=(-2, -1), keepdim=True)
    shifted = data - mins
    maxs = torch.amax(shifted, dim=(-2, -1), keepdim=True)
    return shifted / maxs
| | |
| |
|
def transform_mask(data):
    """Min-max normalize a 2-D mask and replicate it into 3 channels.

    Args:
        data: 2-D numpy array.

    Returns:
        float32 array of shape (3, H, W) where every channel is the input
        scaled to [0, 1].  A constant input yields all zeros (the original
        divided 0/0 and produced NaNs).
    """
    # The subtraction already allocates a fresh array, so the original
    # deepcopy of the input was redundant.
    shifted = data - data.min()
    peak = shifted.max()
    if peak != 0:
        shifted = shifted / peak
    # Broadcast into all three channels instead of the original per-channel
    # assignment loop.
    new_data = np.zeros((3, data.shape[0], data.shape[1]), dtype=np.float32)
    new_data[:] = shifted
    return new_data
| |
|
| |
|
| | |
def Convert_ONNX(model, saveloc, input_data_mock):
    """Export a trained model to ONNX with a dynamic batch dimension.

    Args:
        model: the torch.nn.Module to export.
        saveloc: destination path for the .onnx file.
        input_data_mock: example input tensor used to trace the graph; its
            batch dimension (axis 0) is exported as dynamic.
    """
    print("Saving to ONNX")

    # Inference mode so dropout/batch-norm behave deterministically in the
    # traced graph.
    model.eval()

    # torch.autograd.Variable has been a deprecated no-op since PyTorch 0.4;
    # the example tensor is passed to the exporter directly.
    torch.onnx.export(model,
                      input_data_mock,
                      saveloc,
                      input_names=['modelInput'],
                      output_names=['modelOutput'],
                      dynamic_axes={'modelInput': {0: 'batch_size'},
                                    'modelOutput': {0: 'batch_size'}})
    print(" ")
    print('Model has been converted to ONNX')
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|