| import string, re |
| import numpy as np |
|
|
|
|
| def res_to_list(logits, seq): |
| return logits[:len(seq)] |
|
|
| def res_to_seq(a, mode='mean'): |
| """ |
| Function for how we go from n_values for each amino acid to n_values for each sequence. |
| |
| We leave out padding tokens. |
| """ |
| |
| if mode=='sum': |
| return a[0:(int(a[-1]))].sum() |
| |
| elif mode=='mean': |
| return a[0:(int(a[-1]))].mean() |
| |
| elif mode=='restore': |
| return a[0][0:(int(a[-1]))] |
|
|
| def get_number_alignment(numbered_seqs): |
| """ |
| Creates a number alignment from the anarci results. |
| """ |
| import pandas as pd |
|
|
| alist = [pd.DataFrame(aligned_seq, columns = [0,1,'resi']) for aligned_seq in numbered_seqs] |
| unsorted_alignment = pd.concat(alist).drop_duplicates(subset=0) |
| max_alignment = get_max_alignment() |
| |
| return max_alignment.merge(unsorted_alignment.query("resi!='-'"), left_on=0, right_on=0)[[0,1]] |
|
|
| def get_max_alignment(): |
| """ |
| Create maximum possible alignment for sorting |
| """ |
| import pandas as pd |
|
|
| sortlist = [[("<", "")]] |
| for num in range(1, 128+1): |
| if num in [33,61,112]: |
| for char in string.ascii_uppercase[::-1]: |
| sortlist.append([(num, char)]) |
|
|
| sortlist.append([(num,' ')]) |
| else: |
| sortlist.append([(num,' ')]) |
| for char in string.ascii_uppercase: |
| sortlist.append([(num, char)]) |
| |
| return pd.DataFrame(sortlist + [[(">", "")]]) |
|
|
|
|
| def paired_msa_numbering(ab_seqs, fragmented = False, n_jobs = 10): |
| |
| import pandas as pd |
| |
| tmp_seqs = [pairs.replace(">", "").replace("<", "").split("|") for pairs in ab_seqs] |
|
|
| numbered_seqs_heavy, seqs_heavy, number_alignment_heavy = unpaired_msa_numbering( |
| [i[0] for i in tmp_seqs], 'H', fragmented = fragmented, n_jobs = n_jobs |
| ) |
| numbered_seqs_light, seqs_light, number_alignment_light = unpaired_msa_numbering( |
| [i[1] for i in tmp_seqs], 'L', fragmented = fragmented, n_jobs = n_jobs |
| ) |
| |
| number_alignment = pd.concat([ |
| number_alignment_heavy, |
| pd.DataFrame([[("|",""), "|"]]), |
| number_alignment_light] |
| ).reset_index(drop=True) |
| |
| seqs = [f"{heavy}|{light}" for heavy, light in zip(seqs_heavy, seqs_light)] |
| numbered_seqs = [ |
| heavy + [(("|",""), "|", "|")] + light for heavy, light in zip(numbered_seqs_heavy, numbered_seqs_light) |
| ] |
| |
| return numbered_seqs, seqs, number_alignment |
|
|
|
|
| def unpaired_msa_numbering(seqs, chain = 'H', fragmented = False, n_jobs = 10): |
| |
| numbered_seqs = number_with_anarci(seqs, chain = chain, fragmented = fragmented, n_jobs = n_jobs) |
| number_alignment = get_number_alignment(numbered_seqs) |
| number_alignment[1] = chain |
| |
| seqs = [''.join([i[2] for i in numbered_seq]).replace('-','') for numbered_seq in numbered_seqs] |
| return numbered_seqs, seqs, number_alignment |
|
|
|
|
| def number_with_anarci(seqs, chain = 'H', fragmented = False, n_jobs = 1): |
| |
| import anarci |
| import pandas as pd |
|
|
| anarci_out = anarci.run_anarci( |
| pd.DataFrame(seqs).reset_index().values.tolist(), |
| ncpu=n_jobs, |
| scheme='imgt', |
| allowed_species=['human', 'mouse'], |
| ) |
| |
| numbered_seqs = [] |
| for onarci in anarci_out[1]: |
| numbered_seq = [] |
| for i in onarci[0][0]: |
| if i[1] != '-': |
| numbered_seq.append((i[0], chain, i[1])) |
| |
| if fragmented: |
| numbered_seqs.append(numbered_seq) |
| else: |
| numbered_seqs.append([(("<",""), chain, "<")] + numbered_seq + [((">",""), chain, ">")]) |
| |
| return numbered_seqs |
|
|
|
|
| def create_alignment(res_embeds, numbered_seqs, seq, number_alignment): |
| |
| import pandas as pd |
| |
| datadf = pd.DataFrame(numbered_seqs) |
| sequence_alignment = number_alignment.merge(datadf, how='left', on=[0, 1]).fillna('-')[2] |
|
|
| idxs = np.where(sequence_alignment.values == '-')[0] |
| idxs = [idx-num for num, idx in enumerate(idxs)] |
| |
| aligned_embeds = pd.DataFrame(np.insert(res_embeds[:len(seq)], idxs , 0, axis=0)) |
| |
| return pd.concat([aligned_embeds, sequence_alignment], axis=1).values |
|
|
|
|
| def get_spread_sequences(seq, spread, start_position): |
| """ |
| Test sequences which are 8 positions shorter (position 10 + max CDR1 gap of 7) up to 2 positions longer (possible insertions). |
| """ |
| spread_sequences = [] |
|
|
| for diff in range(start_position-8, start_position+2+1): |
| spread_sequences.append('*'*diff+seq) |
| |
| return np.array(spread_sequences) |
|
|
| def get_sequences_from_anarci(out_anarci, max_position, spread): |
| """ |
| Ensures correct masking on each side of sequence |
| """ |
| |
| if out_anarci == 'ANARCI_error': |
| return np.array(['ANARCI-ERR']*spread) |
| |
| end_position = int(re.search(r'\d+', out_anarci[::-1]).group()[::-1]) |
| |
| start_position = int(re.search(r'\d+,\s\'.\'\),\s\'[^-]+\'\),\s\(\(\d+,\s\'.\'\),\s\'[^-]+\'\),\s\(\(\d+,\s\'.\'\),\s\'[^-]+\'\),\s\(\(\d+,\s\'.\'\),\s\'[^-]+', |
| out_anarci).group().split(',')[0]) - 1 |
| |
| sequence = "".join(re.findall(r"(?i)[A-Z*]", "".join(re.findall(r'\),\s\'[A-Z*]', out_anarci)))) |
|
|
| sequence_j = ''.join(sequence).replace('-','').replace('X','*') + '*'*(max_position-int(end_position)) |
|
|
| return get_spread_sequences(sequence_j, spread, start_position) |
|
|
|
|