| from dataclasses import dataclass |
| import numpy as np |
| import torch |
|
|
| from extra_utils import paired_msa_numbering, unpaired_msa_numbering, create_alignment |
|
|
|
|
class AbAlignment:
    """Number antibody sequences and align per-residue encodings to that numbering.

    The numbering and alignment work is delegated to the ``extra_utils``
    helpers ``paired_msa_numbering`` and ``create_alignment``; this class
    only orchestrates them and re-assembles batched results.
    """

    def __init__(self, device='cpu', ncpu=1):
        """Store the runtime configuration.

        Args:
            device: Device identifier. Stored on the instance; not read by
                any method of this class.
            ncpu: Number of parallel jobs forwarded to the numbering helper
                as ``n_jobs``.
        """
        self.device = device
        self.ncpu = ncpu

    def number_sequences(self, seqs, chain='H', fragmented=False):
        """Number paired sequences and return the shared numbering.

        Args:
            seqs: Sequences to number, passed through to
                ``paired_msa_numbering``.
            chain: Chain type. Only ``'HL'`` (paired) is supported.
            fragmented: Forwarded unchanged to ``paired_msa_numbering``.

        Returns:
            The tuple ``(numbered_seqs, seqs, number_alignment)`` produced by
            ``paired_msa_numbering``.

        Raises:
            AssertionError: If ``chain`` is anything other than ``'HL'``.
        """
        if chain != 'HL':
            # Raised explicitly (instead of via an `assert` statement) so the
            # guard still holds when Python runs with -O. The exception type
            # and message are unchanged for existing callers. Unpaired
            # numbering (`unpaired_msa_numbering`) is not supported here yet.
            raise AssertionError(
                'Currently "Align==True" only works for paired sequences. \nPlease use paired sequences or Align=False.'
            )

        numbered_seqs, seqs, number_alignment = paired_msa_numbering(
            seqs, fragmented=fragmented, n_jobs=self.ncpu
        )
        return numbered_seqs, seqs, number_alignment

    def align_encodings(self, encodings, numbered_seqs, seqs, number_alignment):
        """Align each sequence's encoding and combine them into one array.

        ``encodings``, ``numbered_seqs`` and ``seqs`` are iterated in lockstep
        and each triple is handed to ``create_alignment`` together with the
        shared ``number_alignment``.

        Returns:
            A NumPy array built from the per-sequence alignments, with the
            sequence index as leading axis (assumes ``create_alignment``
            returns equally shaped arrays - TODO confirm).
        """
        aligned_list = [
            create_alignment(res_embed, numbered_seq, seq, number_alignment)
            for res_embed, numbered_seq, seq in zip(encodings, numbered_seqs, seqs)
        ]
        # Concatenating a one-element list converts the list of per-sequence
        # alignments into a single array without merging the sequences.
        return np.concatenate([aligned_list], axis=0)

    @staticmethod
    def _subset_bounds(subset_list):
        """Return the ``(start, stop)`` range each subset covers in the full input."""
        bounds = []
        start = 0
        for subset in subset_list:
            stop = start + len(subset)
            bounds.append((start, stop))
            start = stop
        return bounds

    def reformat_subsets(
        self,
        subset_list,
        mode='seqcoding',
        align=False,
        numbered_seqs=None,
        seqs=None,
        number_alignment=None,
    ):
        """Merge per-batch outputs back into one result.

        Args:
            subset_list: Outputs for consecutive batches, in input order.
            mode: Name of the output mode that produced the subsets.
            align: Whether residue-level outputs should be aligned.
            numbered_seqs: Numbered sequences for the full input; only used
                when aligning.
            seqs: Sequences for the full input; only used when aligning.
            number_alignment: Shared numbering; only used when aligning
                (presumably a pandas DataFrame, since ``.apply(..., axis=1)``
                is called on it - TODO confirm).

        Returns:
            * For the modes ``'seqcoding'``, ``'restore'``,
              ``'pseudo_log_likelihood'`` and ``'confidence'``: the subsets
              concatenated with ``np.concatenate``.
            * Otherwise, with ``align`` falsy: the subsets (lists) joined
              into one flat list.
            * Otherwise: an ``aligned_results`` instance.
        """
        if mode in ('seqcoding', 'restore', 'pseudo_log_likelihood', 'confidence'):
            return np.concatenate(subset_list)

        if not align:
            return sum(subset_list, [])

        # Slice by cumulative offsets so that a batch shorter than the others
        # (typically the last one) is still paired with its own sequences.
        aligned_subsets = [
            self.align_encodings(
                subset,
                numbered_seqs[start:stop],
                seqs[start:stop],
                number_alignment,
            )
            for subset, (start, stop) in zip(subset_list, self._subset_bounds(subset_list))
        ]

        subset = np.concatenate(aligned_subsets)

        # The last entry along the final axis holds the residue characters;
        # everything before it is cast to float as the embedding.
        return aligned_results(
            aligned_seqs=[''.join(alist) for alist in subset[:, :, -1]],
            aligned_embeds=subset[:, :, :-1].astype(float),
            number_alignment=number_alignment.apply(lambda x: '{}{}'.format(*x[0]), axis=1).values,
        )
| |
|
|
@dataclass
class aligned_results:
    """
    Dataclass used to store aligned output.

    The field descriptions reflect how AbAlignment.reformat_subsets fills
    this container; the types are not enforced at runtime.
    """

    # One string per sequence, joined from the aligned residue characters.
    aligned_seqs: list
    # Float array holding the aligned embeddings.
    aligned_embeds: np.ndarray
    # Array of position labels derived from the shared numbering.
    number_alignment: np.ndarray