| import torch |
| import torch.nn.functional as F |
| import numpy as np |
| from tqdm import tqdm |
| import os |
| from datetime import datetime |
| |
# Optional dependency: torchdiffeq provides proper ODE solvers (dopri5, rk4,
# adaptive_heun). When it is missing, generation falls back to the manual
# Euler integrator implemented in AMPGenerator._generate_with_euler.
try:
    from torchdiffeq import odeint
    TORCHDIFFEQ_AVAILABLE = True
    print("✓ torchdiffeq available for proper ODE solving")
except ImportError:
    TORCHDIFFEQ_AVAILABLE = False
    print("⚠️ torchdiffeq not available, using manual Euler integration")
|
|
| |
| from compressor_with_embeddings import Compressor, Decompressor |
| from final_flow_model import AMPFlowMatcherCFGConcat, AMPProtFlowPipelineCFG |
|
|
class AMPGenerator:
    """
    Generate AMP (antimicrobial peptide) samples using a trained ProtFlow model.

    Pipeline (see generate_amps):
      1. Sample Gaussian noise in the compressed latent space [B, 25, 80].
      2. Integrate the learned velocity field backwards in time from t=1
         (noise) to t=0 (data), either with a torchdiffeq ODE solver or a
         manual Euler fallback, optionally applying classifier-free
         guidance (CFG).
      3. Decompress latents to embedding space and invert the normalization
         recorded in 'normalization_stats.pt'.
    """

    def __init__(self, model_path, device='cuda'):
        """
        Args:
            model_path: Path to the flow-model checkpoint; must contain
                'flow_model_state_dict', 'step' and 'loss' entries
                (see _load_models).
            device: Torch device string used for all models and tensors.
        """
        self.device = device

        # Load compressor/decompressor weights and the flow model itself.
        self._load_models(model_path)

        # Normalization statistics (keys: 'mean', 'std', 'min', 'max') saved
        # at training time; inverted in generate_amps() to map outputs back
        # to embedding scale.
        # NOTE(review): loaded from the current working directory — confirm
        # the script is always launched from the expected location.
        self.stats = torch.load('normalization_stats.pt', map_location=device)

    def _load_models(self, model_path):
        """Load trained models: compressor, decompressor and flow model."""
        print("Loading trained models...")

        # Latent autoencoder around the flow model (hard-coded weight paths).
        self.compressor = Compressor().to(self.device)
        self.decompressor = Decompressor().to(self.device)

        self.compressor.load_state_dict(torch.load('/data2/edwardsun/flow_amp/models/final_compressor_model.pth', map_location=self.device))
        self.decompressor.load_state_dict(torch.load('/data2/edwardsun/flow_amp/models/final_decompressor_model.pth', map_location=self.device))

        # Velocity-field transformer. max_seq_len=25 / compressed_dim=80 must
        # match the latent shape sampled in generate_amps ([B, 25, 80]).
        self.flow_model = AMPFlowMatcherCFGConcat(
            hidden_dim=480,
            compressed_dim=80,
            n_layers=12,
            n_heads=16,
            dim_ff=3072,
            max_seq_len=25,
            use_cfg=True
        ).to(self.device)

        # weights_only=False: checkpoint stores non-tensor metadata (step,
        # loss, ...). Only load checkpoints from trusted sources.
        checkpoint = torch.load(model_path, map_location=self.device, weights_only=False)

        # Checkpoints saved from a torch.compile()'d model prefix every key
        # with '_orig_mod.'; strip it so keys match the un-compiled module.
        state_dict = checkpoint['flow_model_state_dict']
        new_state_dict = {}

        for key, value in state_dict.items():
            # len('_orig_mod.') == 10
            if key.startswith('_orig_mod.'):
                new_key = key[10:]
            else:
                new_key = key
            new_state_dict[new_key] = value

        self.flow_model.load_state_dict(new_state_dict)

        print(f"✓ All models loaded successfully from step {checkpoint['step']}!")
        print(f" Loss at checkpoint: {checkpoint['loss']:.6f}")

        # Report which integration backend will be used.
        if TORCHDIFFEQ_AVAILABLE:
            print("✓ Enhanced with proper ODE solving (torchdiffeq)")
        else:
            print("⚠️ Using fallback Euler integration")

    def _create_ode_func(self, cfg_scale=7.5):
        """
        Build the dx/dt callback consumed by torchdiffeq.odeint.

        The callback reads self.current_shape (set by generate_amps before
        each solve) to un-flatten the state, evaluates the flow model —
        with classifier-free guidance when cfg_scale > 0 — and returns the
        flattened velocity.

        Args:
            cfg_scale: CFG guidance weight; 0 disables guidance.
        Returns:
            Callable (t, x_flat) -> dx/dt_flat.
        """

        def ode_func(t, x):
            """
            ODE function: dx/dt = v_theta(x, t)

            Args:
                t: scalar time (single float)
                x: state tensor [B*L*D] (flattened)
            Returns:
                dx/dt: derivative [B*L*D] (flattened)
            """
            # Recover [B, L, D] from the flat state vector.
            batch_size, seq_len, dim = self.current_shape
            x = x.view(batch_size, seq_len, dim)

            # Broadcast the solver's scalar time to a per-sample time vector.
            t_tensor = torch.full((batch_size,), t, device=self.device, dtype=x.dtype)

            if cfg_scale > 0:
                # Conditional pass — label 0 appears to denote the AMP class
                # (TODO confirm against the training label map).
                amp_labels = torch.full((batch_size,), 0, device=self.device)
                vt_cond = self.flow_model(x, t_tensor, labels=amp_labels)

                # Unconditional pass — label 2 appears to be the mask /
                # null-conditioning token (TODO confirm).
                mask_labels = torch.full((batch_size,), 2, device=self.device)
                vt_uncond = self.flow_model(x, t_tensor, labels=mask_labels)

                # Classifier-free guidance: extrapolate from the unconditional
                # velocity toward the conditional one by cfg_scale.
                vt = vt_uncond + cfg_scale * (vt_cond - vt_uncond)
            else:
                # No guidance: single unconditional evaluation.
                mask_labels = torch.full((batch_size,), 2, device=self.device)
                vt = self.flow_model(x, t_tensor, labels=mask_labels)

            # Flatten again for the solver.
            return vt.view(-1)

        return ode_func

    def generate_amps(self, num_samples=100, num_steps=25, batch_size=32, cfg_scale=7.5,
                      ode_method='dopri5', rtol=1e-5, atol=1e-6):
        """
        Generate AMP samples using flow matching with CFG and improved ODE solving.

        Args:
            num_samples: Number of AMP samples to generate
            num_steps: Number of ODE solving steps (25 for good quality, 1 for reflow)
            batch_size: Batch size for generation
            cfg_scale: CFG guidance scale (higher = stronger conditioning)
            ode_method: ODE solver method ('dopri5', 'rk4', 'euler', 'adaptive_heun')
            rtol: Relative tolerance for adaptive solvers
            atol: Absolute tolerance for adaptive solvers

        Returns:
            CPU tensor of denormalized embeddings, concatenated over batches.

        NOTE(review): when torchdiffeq is available and ode_method is
        adaptive ('dopri5'/'adaptive_heun'), num_steps is ignored — the
        solver chooses its own step count. Fixed-step solvers use a
        hard-coded step_size of 0.04 (~25 steps) regardless of num_steps.
        """
        method_str = f"{ode_method} ODE solver" if TORCHDIFFEQ_AVAILABLE and ode_method != 'euler' else "manual Euler integration"
        print(f"Generating {num_samples} AMP samples with {method_str} (CFG scale: {cfg_scale})...")
        if TORCHDIFFEQ_AVAILABLE and ode_method != 'euler':
            print(f" Method: {ode_method}, rtol={rtol}, atol={atol}")

        # Inference mode for all modules.
        self.flow_model.eval()
        self.compressor.eval()
        self.decompressor.eval()

        all_generated = []

        with torch.no_grad():
            for i in tqdm(range(0, num_samples, batch_size), desc="Generating with improved ODE"):
                current_batch = min(batch_size, num_samples - i)

                # Start from Gaussian noise in latent space; shape [B, 25, 80]
                # matches the flow model's max_seq_len / compressed_dim.
                eps = torch.randn(current_batch, 25, 80, device=self.device)

                if TORCHDIFFEQ_AVAILABLE and ode_method != 'euler':
                    try:
                        # The ODE callback un-flattens the state via this.
                        self.current_shape = eps.shape

                        ode_func = self._create_ode_func(cfg_scale=cfg_scale)

                        # Integrate backwards in time: t=1 (noise) -> t=0 (data).
                        t_span = torch.tensor([1.0, 0.0], device=self.device, dtype=eps.dtype)

                        # torchdiffeq integrates a flat state vector here.
                        y0 = eps.view(-1)

                        if ode_method in ['dopri5', 'adaptive_heun']:
                            # Adaptive step-size solvers honour rtol/atol;
                            # max_num_steps bounds runaway refinement.
                            solution = odeint(
                                ode_func, y0, t_span,
                                method=ode_method,
                                rtol=rtol,
                                atol=atol,
                                options={'max_num_steps': 1000}
                            )
                        else:
                            # Fixed-step solvers (e.g. rk4); 0.04 ≈ 1/25.
                            solution = odeint(
                                ode_func, y0, t_span,
                                method=ode_method,
                                options={'step_size': 0.04}
                            )

                        # solution[-1] is the state at the last time point (t=0).
                        xt = solution[-1].view(self.current_shape)

                    except Exception as e:
                        # Best-effort: any solver failure degrades to Euler
                        # rather than aborting the whole generation run.
                        print(f"⚠️ ODE solving failed for batch {i//batch_size + 1}: {e}")
                        print("Falling back to Euler method...")

                        xt = self._generate_with_euler(eps, current_batch, cfg_scale, num_steps)
                else:
                    # torchdiffeq unavailable or Euler explicitly requested.
                    xt = self._generate_with_euler(eps, current_batch, cfg_scale, num_steps)

                # Map latents back to full embedding space.
                decompressed = self.decompressor(xt)

                # Invert training-time normalization: undo min-max scaling,
                # then undo standardization. Assumes training applied
                # standardize -> min-max in that order — TODO confirm against
                # the preprocessing code that wrote normalization_stats.pt.
                m, s, mn, mx = self.stats['mean'], self.stats['std'], self.stats['min'], self.stats['max']
                decompressed = decompressed * (mx - mn + 1e-8) + mn
                decompressed = decompressed * s + m

                all_generated.append(decompressed.cpu())

        # Concatenate all batches on CPU.
        generated_embeddings = torch.cat(all_generated, dim=0)

        print(f"✓ Generated {generated_embeddings.shape[0]} AMP embeddings")
        print(f" Shape: {generated_embeddings.shape}")
        print(f" Stats - Mean: {generated_embeddings.mean():.4f}, Std: {generated_embeddings.std():.4f}")

        return generated_embeddings

    def _generate_with_euler(self, eps, current_batch, cfg_scale, num_steps):
        """Fallback Euler integration method (original implementation)."""
        xt = eps.clone()
        # Label conventions mirror _create_ode_func: 0 = AMP class,
        # 2 = mask/unconditional token (TODO confirm against training code).
        amp_labels = torch.full((current_batch,), 0, device=self.device)
        mask_labels = torch.full((current_batch,), 2, device=self.device)

        for step in range(num_steps):
            # Time decreases 1.0 -> 1/num_steps across the loop (noise -> data).
            t = torch.ones(current_batch, device=self.device) * (1.0 - step/num_steps)

            if cfg_scale > 0:
                # Classifier-free guidance: two forward passes per step.
                vt_cond = self.flow_model(xt, t, labels=amp_labels)

                vt_uncond = self.flow_model(xt, t, labels=mask_labels)

                # Extrapolate toward the conditional velocity.
                vt = vt_uncond + cfg_scale * (vt_cond - vt_uncond)
            else:
                # Unconditional generation only.
                vt = self.flow_model(xt, t, labels=mask_labels)

            # Explicit Euler step backwards in time.
            dt = -1.0 / num_steps
            xt = xt + vt * dt

        return xt

    def compare_ode_methods(self, num_samples=20, cfg_scale=7.5):
        """
        Compare different ODE solving methods for quality assessment.

        Returns a dict keyed by method name with timing and embedding
        statistics (or the error string when a method fails).

        NOTE(review): timing uses torch.cuda.Event and
        torch.cuda.synchronize(), so this method assumes a CUDA device —
        it will fail if the generator was constructed with device='cpu'.
        """
        if not TORCHDIFFEQ_AVAILABLE:
            print("⚠️ torchdiffeq not available, cannot compare ODE methods")
            return self.generate_amps(num_samples=num_samples, cfg_scale=cfg_scale)

        methods = ['euler', 'rk4', 'dopri5', 'adaptive_heun']
        results = {}

        print("🔬 Comparing ODE solving methods...")

        for method in methods:
            print(f"\n--- Testing {method} ---")
            try:
                # CUDA events give GPU-accurate wall-clock timing.
                start_time = torch.cuda.Event(enable_timing=True)
                end_time = torch.cuda.Event(enable_timing=True)

                start_time.record()
                embeddings = self.generate_amps(
                    num_samples=num_samples,
                    batch_size=10,
                    cfg_scale=cfg_scale,
                    ode_method=method
                )
                end_time.record()

                # Events record asynchronously; sync before reading elapsed time.
                torch.cuda.synchronize()
                elapsed_time = start_time.elapsed_time(end_time) / 1000.0

                results[method] = {
                    'embeddings': embeddings,
                    'time': elapsed_time,
                    'mean': embeddings.mean().item(),
                    'std': embeddings.std().item(),
                    'success': True
                }

                print(f"✓ {method}: {elapsed_time:.2f}s, mean={embeddings.mean():.4f}, std={embeddings.std():.4f}")

            except Exception as e:
                # A failing method is recorded, not fatal — the comparison
                # continues with the remaining solvers.
                print(f"❌ {method} failed: {e}")
                results[method] = {'success': False, 'error': str(e)}

        return results

    def generate_with_reflow(self, num_samples=100):
        """
        Generate AMP samples using 1-step reflow (if you have reflow model).

        NOTE(review): this delegates to generate_amps with num_steps=1, but
        the default ode_method='dopri5' ignores num_steps when torchdiffeq
        is available — pass ode_method='euler' for a true single step.
        """
        print(f"Generating {num_samples} AMP samples with 1-step reflow...")

        return self.generate_amps(num_samples=num_samples, num_steps=1, batch_size=32)
|
|
def main():
    """
    Entry point: validate the best checkpoint, build an AMPGenerator,
    optionally benchmark ODE solvers, then generate and save one sample
    set per CFG guidance scale.
    """
    print("=== AMP Generation Pipeline with CFG ===")

    model_path = '/data2/edwardsun/flow_checkpoints/amp_flow_model_best_optimized.pth'

    # Sanity-check the checkpoint before constructing the (expensive) generator.
    # Narrow except clauses: OSError covers a missing/unreadable file, KeyError
    # a checkpoint without the expected metadata, RuntimeError a corrupt
    # serialized file. A bare `except:` here would also swallow
    # KeyboardInterrupt and hide unrelated bugs.
    try:
        checkpoint = torch.load(model_path, map_location='cpu', weights_only=False)
        print(f"✓ Found best model at step {checkpoint['step']} with loss {checkpoint['loss']:.6f}")
        print(f" Global step: {checkpoint['global_step']}")
        print(f" Total samples: {checkpoint['total_samples']:,}")
    except (OSError, KeyError, RuntimeError) as err:
        print(f"❌ Best model not found: {model_path}")
        print(f" Reason: {err}")
        print("Please train the flow matching model first using amp_flow_training.py")
        return

    generator = AMPGenerator(model_path, device='cuda')

    # Pick the solver: benchmark the available methods when torchdiffeq is
    # present, otherwise fall back to Euler.
    if TORCHDIFFEQ_AVAILABLE:
        print("\n🔬 Comparing ODE solving methods...")
        generator.compare_ode_methods(num_samples=10, cfg_scale=7.5)

        best_method = 'dopri5'
        print(f"\n🚀 Using {best_method} for main generation...")
    else:
        best_method = 'euler'
        print("\n⚠️ Using fallback Euler integration...")

    # One generation run per CFG scale:
    # (index, cfg scale, label for the progress print, filename tag,
    #  label for the summary print).
    cfg_runs = [
        (1, 0.0, 'no conditioning', 'no_cfg', 'no conditioning'),
        (2, 3.0, 'weak conditioning', 'weak_cfg', 'weak CFG'),
        (3, 7.5, 'strong conditioning', 'strong_cfg', 'strong CFG'),
        (4, 15.0, 'very strong conditioning', 'very_strong_cfg', 'very strong CFG'),
    ]

    output_dir = '/data2/edwardsun/generated_samples'
    os.makedirs(output_dir, exist_ok=True)

    # Date suffix keeps runs from different days side by side.
    today = datetime.now().strftime('%Y%m%d')

    # Save each sample set as soon as it is generated so a crash mid-run
    # does not lose the sets already completed.
    for idx, scale, run_desc, tag, _ in cfg_runs:
        print(f"\n{idx}. Generating with CFG scale {scale} ({run_desc})...")
        samples = generator.generate_amps(num_samples=20, num_steps=25, cfg_scale=scale, ode_method=best_method)
        torch.save(samples, os.path.join(output_dir, f'generated_amps_best_model_{tag}_{today}.pt'))

    print("\n✓ Generation complete!")
    print(f"Generated samples saved (Date: {today}):")
    for _, _, _, tag, save_desc in cfg_runs:
        print(f" - generated_amps_best_model_{tag}_{today}.pt ({save_desc})")

    print("\nCFG Analysis:")
    print(" - CFG scale 0.0: No conditioning, generates diverse sequences")
    print(" - CFG scale 3.0: Weak AMP conditioning")
    print(" - CFG scale 7.5: Strong AMP conditioning (recommended)")
    print(" - CFG scale 15.0: Very strong AMP conditioning (may be too restrictive)")

    print("\nNext steps:")
    print("1. Decode embeddings back to sequences using ESM-2 decoder")
    print("2. Evaluate with ProtFlow metrics (FPD, MMD, ESM-2 perplexity)")
    print("3. Compare sequences generated with different CFG scales")
    print("4. Evaluate AMP properties (antimicrobial activity, toxicity)")
    if TORCHDIFFEQ_AVAILABLE:
        print(f"5. ✓ Enhanced generation with {best_method} ODE solver")
    else:
        print("5. Install torchdiffeq for improved ODE solving: pip install torchdiffeq")
|
|
# Allow running this module directly as a script.
if __name__ == "__main__":
    main()