| | import torch |
| | from location_encoder import get_neural_network, get_positional_encoding, LocationEncoder |
| |
|
| |
|
def get_satclip_loc_encoder(ckpt_path, device):
    """Build a SatCLIP location encoder from a training checkpoint.

    Only the location-encoder network weights are restored: entries of the
    checkpoint's ``state_dict`` whose key does not contain ``'nnet'`` are
    ignored, so no other part of the trained model has to be constructed.

    Args:
        ckpt_path: Checkpoint location, anything accepted by ``torch.load``.
        device: Device the checkpoint tensors are mapped to and the returned
            encoder is placed on (e.g. ``"cpu"`` or ``"cuda"``).

    Returns:
        A ``LocationEncoder`` with float64 parameters, in eval mode, on
        ``device``.

    Raises:
        KeyError: if ``'hyper_parameters'``, ``'state_dict'`` or one of the
            hyper-parameters read below is missing from the checkpoint.
        RuntimeError: from ``load_state_dict`` if the filtered weights do not
            match the rebuilt network.
    """
    # NOTE(review): torch.load unpickles the file -- only load checkpoints
    # from a trusted source.
    ckpt = torch.load(ckpt_path, map_location=device)
    hp = ckpt['hyper_parameters']

    # NOTE(review): the names look swapped ('le_type' feeds the positional
    # encoding factory, 'pe_type' the network factory); presumably this
    # mirrors how the training code used them -- confirm before changing.
    posenc = get_positional_encoding(
        hp['le_type'],
        hp['legendre_polys'],
        hp['harmonics_calculation'],
        hp['min_radius'],
        hp['max_radius'],
        hp['frequency_num'],
    )

    nnet = get_neural_network(
        hp['pe_type'],
        posenc.embedding_dim,
        hp['embed_dim'],
        hp['capacity'],
        hp['num_hidden_layers'],
    )

    # Keep only the network weights and strip whatever prefix precedes
    # 'nnet' in each key, so the keys start at 'nnet' (presumably the
    # attribute name under which LocationEncoder stores the network).
    state_dict = {
        key[key.index('nnet'):]: value
        for key, value in ckpt['state_dict'].items()
        if 'nnet' in key
    }

    # Build in float64 and move to the requested device *before* loading:
    # load_state_dict copies values into the existing parameters, so without
    # .to(device) the encoder would stay on the CPU regardless of `device`.
    loc_encoder = LocationEncoder(posenc, nnet).double().to(device)
    loc_encoder.load_state_dict(state_dict)
    loc_encoder.eval()

    return loc_encoder
| | |