| import os |
| import torch |
| import stat |
| import re |
| import time |
| import argparse |
| import numpy as np |
|
|
| from functools import partial |
| from typing import List, Tuple |
|
|
| import torch.distributed as dist |
| from sat.helpers import print_rank0 |
| from sat import mpu, get_args, get_tokenizer |
| from utils import AdvancedBaseStrategy, BeamSearchStrategy |
| from model_utils import MSAGPT, FineTuneMSAGPT |
| from utils import chat_api |
|
|
|
|
|
|
if __name__ == "__main__":
    # CLI options controlling sampling/generation. Model-specific arguments are
    # appended by MSAGPT.add_model_specific_args below, and any flags we do not
    # recognise are forwarded untouched to sat's get_args().
    py_parser = argparse.ArgumentParser(add_help=False)
    py_parser.add_argument("--sampling-strategy", type=str, default="BaseStrategy", help="Type of sampling strategy.")
    py_parser.add_argument("--min-gen-length", type=int, default=0, help="The minimum length each blank should generate.")
    # Fixed copy-pasted help text: this flag bounds the *maximum* generation length.
    py_parser.add_argument("--max-gen-length", type=int, default=512, help="The maximum length each blank should generate.")
    # NOTE(review): the help text below looks copy-pasted from --print-all-beams;
    # confirm what --is-valid actually toggles before rewording it.
    py_parser.add_argument("--is-valid", action="store_true", help="Print all output generated by beam search strategy.")
    py_parser.add_argument("--print-all-beams", action="store_true", help="Print all output generated by beam search strategy.")
    py_parser.add_argument("--multiline_stream", action="store_true", help="streaming multiline output.")
    py_parser.add_argument("--no-gap", action="store_true", help="do not generate gaps.")
    py_parser.add_argument("--from_pretrained", type=str, default="./checkpoints/MSAGPT", help='pretrained ckpt')
    py_parser.add_argument("--chinese", action='store_true', help='Chinese interface')
    py_parser.add_argument("--stream_chat", action='store_true', help='streaming output')

    py_parser = MSAGPT.add_model_specific_args(py_parser)
    known, args_list = py_parser.parse_known_args()
    args = get_args(args_list)
    # Merge sat's parsed args with the locally parsed ones into one namespace.
    args = argparse.Namespace(**vars(args), **vars(known))
    # Only override model_parallel_size in the checkpoint config when the user
    # asked for a non-default (>1) value.
    model, args = MSAGPT.from_pretrained(
        args.from_pretrained,
        args,
        overwrite_args={'model_parallel_size': args.model_parallel_size} if args.model_parallel_size != 1 else {},
    )
    model.eval()
    # Rank / world size come from the distributed launcher environment;
    # default to a single local process when unset.
    rank = int(os.environ.get('RANK', 0))
    world_size = int(os.environ.get('WORLD_SIZE', 1))
    if torch.cuda.is_available():
        model = model.to('cuda')
    from utils import proteinglm_tokenizer
    tokenizer = proteinglm_tokenizer()
|
|
| end_tokens = [tokenizer.get_command("eop"), tokenizer.get_command("eos")] |
| |
| invalid_slices = [0,26,28,29,30,31,32] |
| if args.no_gap: |
| invalid_slices.append(tokenizer.TokenToId('-')) |
| if args.sampling_strategy == "BaseStrategy": |
| assert not args.print_all_beams, "BaseStrategy don't support print all beams." |
| strategy = AdvancedBaseStrategy( |
| batch_size=1, invalid_slices = invalid_slices, temperature=args.temperature, top_k=args.top_k, top_p=args.top_p, min_gen_length=args.min_gen_length, no_repeat_ngram_size=args.no_repeat_ngram_size, end_tokens=end_tokens |
| ) |
| elif args.sampling_strategy == "BeamSearchStrategy": |
| strategy = BeamSearchStrategy( |
| 1, |
| args.num_beams, |
| length_penalty=args.length_penalty, |
| consider_end=True, |
| end_tokens=end_tokens, |
| invalid_slices=invalid_slices, |
| no_repeat_ngram_size=args.no_repeat_ngram_size, |
| min_gen_length=args.min_gen_length, |
| deterministic=True |
| ) |
| else: |
| raise ValueError(f"unknown strategy {args.sampling_strategy}") |
|
|
|
|
|
|
| if args.input_source == 'chat': |
| if args.chinese: |
| if rank == 0: |
| print('欢迎使用 MSAGPT-CLI ,输入需要生成虚拟MSA的蛋白序列(或加上少量MSA作为prompt,以"<M>"相连),例如:"PEGKQGDPGIPGEPGPPGPPGPQGARGPPG<M>VTVEFVNSCLIGDMGVDGPPGQQGQPGPPG",其中"PEGKQGDPGIPGEPGPPGPPGPQGARGPPG"为主序列,"VTVEFVNSCLIGDMGVDGPPGQQGQPGPPG"为MSA prompt。 stop 终止程序'.center(20, "*")) |
| else: |
| if rank == 0: |
| print('Welcome to MSAGPT-CLI. Enter the protein sequence you need to generate virtual MSAs (or add a few MSAs as a prompt, connected by "<M>"), for example: "PEGKQGDPGIPGEPGPPGPPGPQGARGPPG<M>VTVEFVNSCLIGDMGVDGPPGQQGQPGPPG", where "PEGKQGDPGIPGEPGPPGPPGPQGARGPPG" is the main sequence, and "VTVEFVNSCLIGDMGVDGPPGQQGQPGPPG" are MSA prompts. Type "stop" to end the program.'.center(20,"*")) |
| with torch.no_grad(): |
| while True: |
| if args.chinese: |
| if rank == 0: |
| protein_input = input("请输入需要生成虚拟MSA的蛋白序列(或加上少量MSA作为prompt,以'<M>'相连):") |
| else: |
| protein_input = None |
| else: |
| if rank == 0: |
| protein_input = input("Enter the protein sequence you need to generate virtual MSAs (or add a few MSAs as a prompt, connected by '<M>': ") |
| else: |
| protein_input = None |
| if world_size > 1: |
| torch.distributed.broadcast_object(protein_input, 0) |
| protein_input = protein_input.strip() |
| assert protein_input is not None |
|
|
| if protein_input == 'stop': |
| break |
| |
| try: |
| response = chat_api( |
| args=args, |
| query=protein_input, |
| model=model, |
| tokenizer=tokenizer, |
| strategy=strategy |
| ) |
| except Exception as e: |
| print(e) |
| break |
| if rank == 0 and not args.stream_chat: |
| if args.chinese: |
| print(f"{'生成的MSA'.center(20, '*')}") |
| else: |
| print(f"{'Virtual MSA'.center(20, '*')}") |
| if args.print_all_beams: |
| for idx, gen in enumerate(response): |
| out_str = f"Beam: {idx}".center(11,'@') |
| print(out_str) |
| for _ in gen: |
| print(_) |
| print() |
| else: |
| response = response[0] |
| for _ in response: |
| print(_) |
| print() |
| else: |
| chat_api( |
| args=args, |
| model=model, |
| tokenizer=tokenizer, |
| strategy=strategy |
| ) |