| import argparse |
| import math |
| import torch |
| import torch.nn as nn |
| from tqdm import tqdm |
| from transformers import AutoTokenizer |
|
|
| |
def modulate(x, shift, scale):
    """
    Apply an affine modulation to ``x``: multiply by ``1 + scale``, then add ``shift``.

    ``shift`` and ``scale`` each get a new axis inserted at position 1 so that
    they broadcast over dimension 1 of ``x``.
    """
    gain = scale.unsqueeze(1) + 1
    offset = shift.unsqueeze(1)
    return x * gain + offset
|
|
class TimestepEmbedder(nn.Module):
    """
    Turns one scalar timestep per batch element (intended range [0, 1]) into a
    ``hidden_size``-dimensional vector with a small MLP: Linear -> SiLU -> Linear.
    """

    def __init__(self, hidden_size):
        super().__init__()
        layers = [
            nn.Linear(1, hidden_size, bias=True),
            nn.SiLU(),
            nn.Linear(hidden_size, hidden_size, bias=True),
        ]
        self.mlp = nn.Sequential(*layers)

    def forward(self, t):
        # Add a trailing feature axis so the first Linear (in_features=1) applies.
        t_column = t.unsqueeze(-1)
        return self.mlp(t_column)
|
|
class DiTBlock(nn.Module):
    """
    One Diffusion Transformer layer.

    Self-attention and a feed-forward MLP are each preceded by an affine-free
    LayerNorm whose shift and scale are predicted from a conditioning vector,
    and each residual branch is multiplied by a predicted gate.
    """

    def __init__(self, hidden_size, n_heads):
        super().__init__()
        mlp_width = 4 * hidden_size
        # Norms carry no learnable affine; shift/scale come from adaLN_modulation.
        self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.attn = nn.MultiheadAttention(hidden_size, n_heads, batch_first=True)
        self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.mlp = nn.Sequential(
            nn.Linear(hidden_size, mlp_width),
            nn.GELU(),
            nn.Linear(mlp_width, hidden_size),
        )
        # Emits six hidden_size-wide chunks: (shift, scale, gate) per sub-layer.
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(),
            nn.Linear(hidden_size, 6 * hidden_size, bias=True),
        )

    def forward(self, x, c):
        params = self.adaLN_modulation(c).chunk(6, dim=1)
        attn_shift, attn_scale, attn_gate = params[:3]
        mlp_shift, mlp_scale, mlp_gate = params[3:]

        # Self-attention sub-layer with gated residual connection.
        h = modulate(self.norm1(x), attn_shift, attn_scale)
        attn_out = self.attn(h, h, h)[0]
        x = x + attn_gate.unsqueeze(1) * attn_out

        # Feed-forward sub-layer with gated residual connection.
        h = modulate(self.norm2(x), mlp_shift, mlp_scale)
        x = x + mlp_gate.unsqueeze(1) * self.mlp(h)
        return x
|
|
class MDLM(nn.Module):
    """
    Masked Diffusion Language Model (MDLM) built on a stack of DiT blocks.

    The embedding table has ``vocab_size + 1`` rows: ids ``0..vocab_size - 1``
    plus one extra id, ``vocab_size``, stored as ``mask_token_id``. The output
    head produces logits over the ``vocab_size`` regular tokens only.
    """

    def __init__(self, vocab_size, seq_len, model_dim, n_heads, n_layers):
        super().__init__()
        self.vocab_size = vocab_size
        self.seq_len = seq_len
        self.model_dim = model_dim
        self.mask_token_id = vocab_size

        # Input embeddings: tokens, learned positions, and the timestep.
        self.token_embedder = nn.Embedding(vocab_size + 1, model_dim)
        self.pos_embedder = nn.Parameter(torch.randn(1, seq_len, model_dim))
        self.time_embedder = TimestepEmbedder(model_dim)

        self.transformer_blocks = nn.ModuleList()
        for _ in range(n_layers):
            self.transformer_blocks.append(DiTBlock(model_dim, n_heads))

        self.final_norm = nn.LayerNorm(model_dim)
        self.lm_head = nn.Linear(model_dim, vocab_size)

    def forward(self, x, t):
        n_positions = x.shape[1]
        # Slice the positional table so inputs shorter than seq_len are accepted.
        hidden = self.token_embedder(x) + self.pos_embedder[:, :n_positions, :]
        cond = self.time_embedder(t)
        for layer in self.transformer_blocks:
            hidden = layer(hidden, cond)
        return self.lm_head(self.final_norm(hidden))
|
|
| |
|
|
def generate_samples(model, device, num_samples, seq_len, steps, temperature):
    """
    Generate token sequences by progressively refining uniformly random tokens.

    At every step the model predicts a token distribution for each position
    and one candidate token is sampled per position. Only the candidates with
    the highest sampled probability are written back; the number written back
    per step follows a cosine schedule that starts at 0 and rises towards
    ``seq_len``. On the final step every position takes its sampled token.

    Args:
        model: Module called as ``model(x, t)`` that returns logits of shape
            ``(num_samples, seq_len, model.vocab_size)`` and exposes a
            ``vocab_size`` attribute. It is switched to eval mode.
        device: Device on which all tensors are created.
        num_samples: Number of sequences to generate.
        seq_len: Length of each generated sequence.
        steps: Number of refinement steps; must be at least 1.
        temperature: Softmax temperature; must be strictly positive.

    Returns:
        Long tensor of token ids with shape ``(num_samples, seq_len)``.

    Raises:
        ValueError: If ``steps`` is below 1 or ``temperature`` is not positive.
    """
    # Fail early: zero steps would return pure noise, and a non-positive
    # temperature yields inf/NaN probabilities or an inverted distribution.
    if steps < 1:
        raise ValueError(f"steps must be at least 1, got {steps}.")
    if not temperature > 0:
        raise ValueError(f"temperature must be positive, got {temperature}.")

    model.eval()

    shape = (num_samples, seq_len)
    vocab_size = model.vocab_size
    x = torch.randint(0, vocab_size, shape, dtype=torch.long, device=device)

    # Cosine schedule of how many positions to overwrite at each step.
    # Converted to Python ints once so topk receives a plain int and the loop
    # does not synchronise with the device on every iteration.
    keep_schedule = torch.cos(torch.linspace(math.pi / 2, 0, steps, device=device)) * seq_len
    keep_counts = torch.round(keep_schedule).long().tolist()

    with torch.no_grad():
        progress_bar = tqdm(range(steps), desc="Generating Samples")
        for i in progress_bar:
            t_continuous = torch.full((num_samples,), i / steps, device=device)
            logits = model(x, t_continuous)

            probs = torch.nn.functional.softmax(logits / temperature, dim=-1)
            sampled_tokens = torch.multinomial(probs.reshape(-1, vocab_size), 1).view(shape)

            # Last step: accept the sampled token at every position.
            if i == steps - 1:
                x = sampled_tokens
                break

            # Confidence is the probability the model gave to the sampled token.
            confidence = torch.gather(probs, 2, sampled_tokens.unsqueeze(-1)).squeeze(-1)
            _, indices_to_keep = torch.topk(confidence, keep_counts[i], largest=True, dim=-1)
            keep_mask = torch.zeros_like(x, dtype=torch.bool).scatter_(1, indices_to_keep, True)

            # Overwrite only the most confident positions; the rest keep their
            # current token until a later step.
            x = torch.where(keep_mask, sampled_tokens, x)

    return x
|
|
| |
|
|
def _strip_boundary_tokens(sequence, prefix="<cls>", suffix="<eos>"):
    """Remove a leading ``prefix`` and a trailing ``suffix`` from ``sequence`` only when present."""
    if prefix and sequence.startswith(prefix):
        sequence = sequence[len(prefix):]
    if suffix and sequence.endswith(suffix):
        sequence = sequence[:-len(suffix)]
    return sequence


def main(args):
    """
    Load a checkpoint, generate samples, decode them and write them to a file.

    Args:
        args: Parsed command-line namespace providing ``checkpoint``,
            ``num_samples``, ``output_file``, ``gen_steps``, ``gen_len`` and
            ``temperature``.

    Raises:
        ValueError: If the requested generation length is not positive or
            exceeds the sequence length stored in the checkpoint.

    Checkpoint loading errors are reported on stdout and the function returns
    without generating anything.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    print(f"Loading checkpoint from {args.checkpoint}...")
    try:
        # SECURITY: weights_only=False unpickles arbitrary Python objects and can
        # execute code embedded in the file. It is kept because the checkpoint
        # stores its hyper-parameters under 'args' as an attribute-style object.
        # Only load checkpoints from a trusted source.
        checkpoint = torch.load(args.checkpoint, map_location=device, weights_only=False)
        model_args = checkpoint['args']
    except FileNotFoundError:
        print(f"Error: Checkpoint file not found at {args.checkpoint}")
        return
    except Exception as e:
        print(f"Error loading checkpoint: {e}")
        return

    print("Initializing model...")
    model = MDLM(
        vocab_size=model_args.vocab_size,
        seq_len=model_args.seq_len,
        model_dim=model_args.model_dim,
        n_heads=model_args.n_heads,
        n_layers=model_args.n_layers
    ).to(device)

    model.load_state_dict(checkpoint['model_state_dict'])
    print("Model loaded successfully.")

    # Fall back to the length stored in the checkpoint when none is requested.
    gen_len = args.gen_len if args.gen_len is not None else model_args.seq_len
    if gen_len < 1:
        raise ValueError(f"Requested generation length ({gen_len}) must be a positive integer.")
    if gen_len > model_args.seq_len:
        raise ValueError(f"Requested generation length ({gen_len}) is greater than the model's max length ({model_args.seq_len}).")
    print(f"Generating sequences of length {gen_len}.")

    generated_tokens = generate_samples(
        model=model,
        device=device,
        num_samples=args.num_samples,
        seq_len=gen_len,
        steps=args.gen_steps,
        temperature=args.temperature
    )

    print("Decoding and saving samples...")
    tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D")

    with open(args.output_file, 'w', encoding="utf-8") as f:
        for sample_tokens in generated_tokens:
            sequence = tokenizer.decode(sample_tokens.tolist(), skip_special_tokens=False)
            # Sampled sequences are not guaranteed to begin with <cls> or end
            # with <eos>, so strip those markers only when they are actually
            # there; a fixed-width slice would cut off real residues otherwise.
            clean_sequence = _strip_boundary_tokens(sequence.replace(" ", ""))
            f.write(clean_sequence + "\n")
            print(clean_sequence)

    print(f"Generation complete. {args.num_samples} sequences saved to {args.output_file}")
|
|
|
|
if __name__ == "__main__":
    # Command-line options as (flag, add_argument keyword arguments), listed in
    # the order they should appear in --help.
    cli_options = [
        ("--checkpoint", dict(type=str, required=True, help="Path to the model checkpoint file.")),
        ("--num_samples", dict(type=int, default=128, help="Number of samples to generate.")),
        ("--output_file", dict(type=str, default="./generated_peptides.txt", help="File to save the generated peptide sequences.")),
        ("--gen_steps", dict(type=int, default=16, help="Number of steps for the progressive refinement process.")),
        ("--gen_len", dict(type=int, default=None, help="Desired length of the generated sequences. Defaults to the model's maximum trained length.")),
        ("--temperature", dict(type=float, default=1.0, help="Sampling temperature. >1 increases diversity, <1 decreases it.")),
    ]

    parser = argparse.ArgumentParser(description="Generate samples from a trained ReDi (MDLM) model starting from random noise.")
    for flag, options in cli_options:
        parser.add_argument(flag, **options)

    args = parser.parse_args()
    main(args)
|
|