| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| import subprocess |
| import sys |
|
|
# Bootstrap: ensure a working sympy install before anything else runs.
# The attribute probe (`sympy.core`) also catches a broken/partial install,
# not just a missing package.  NOTE(review): sympy is never used elsewhere in
# this file — presumably needed by a downstream stage; confirm before removing.
try:
    import sympy
    _ = sympy.core
except (ImportError, AttributeError):
    # --break-system-packages is required on PEP 668 "externally managed"
    # environments (e.g. Debian/Ubuntu system Python).
    subprocess.check_call(
        [sys.executable, "-m", "pip", "install", "--upgrade", "sympy", "--break-system-packages", "-q"]
    )
|
|
| import gc |
| import os |
| import shutil |
| import time |
| from dataclasses import dataclass |
| from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple |
|
|
| import numpy as np |
| import torch |
| from torch.utils.data import Dataset, DataLoader |
| from datasets import ( |
| Audio, |
| Dataset as HFDataset, |
| Features, |
| Sequence, |
| Value, |
| Array2D, |
| concatenate_datasets, |
| load_dataset, |
| load_from_disk, |
| ) |
|
|
| |
| |
| |
|
|
@dataclass
class BaseConfig:
    """Global configuration for the multi-expert precompute pipeline."""

    # Where encoded (text, expert) pair datasets are cached on disk.
    cache_dir: str = "/home/claude/geo_cache"
    # Fixed token length for every BERT-encoded caption/doc string.
    max_text_len: int = 32

    # Device / mixed-precision selection: CUDA + fp16 when available, else CPU.
    device: str = "cuda" if torch.cuda.is_available() else "cpu"
    amp_enabled: bool = torch.cuda.is_available()

    # Shared text encoder used for all experts.
    bert_model_name: str = "google-bert/bert-large-uncased"
    bert_hidden_dim: int = 1024

    # DataLoader settings for the map-style (image/code) encoders.
    batch_size: int = 256
    num_workers: int = 8
    prefetch_factor: int = 2
    pin_memory: bool = torch.cuda.is_available()

    # Rows buffered in memory before ShardWriter spills one shard to disk.
    shard_size_default: int = 2048

    # Per-expert sample caps for the streaming datasets.
    max_audio_samples: int = 10000
    max_protein_samples: int = 15000
    max_code_samples: int = 50000

    # Delete ~/.cache/huggingface between experts to bound peak disk usage.
    cleanup_hf_cache_between_experts: bool = True


# Module-wide singleton config and the resolved torch device.
CFG = BaseConfig()
DEVICE = torch.device(CFG.device)
|
|
|
|
| |
| |
| |
|
|
def cleanup_hf_cache() -> None:
    """Delete HF datasets/hub cache to free disk after encoding an expert."""
    hf_cache = os.path.expanduser("~/.cache/huggingface")
    for subdir in ["datasets", "hub"]:
        p = os.path.join(hf_cache, subdir)
        if not os.path.exists(p):
            continue

        # Tally the subtree size first so the log line reports what was freed.
        total_bytes = 0
        for dirpath, _dirnames, filenames in os.walk(p):
            for filename in filenames:
                try:
                    total_bytes += os.path.getsize(os.path.join(dirpath, filename))
                except OSError:
                    # A file may vanish mid-walk; skip it and keep counting.
                    pass
        size_gb = total_bytes / 1e9

        print(f" Cleaning {p} ({size_gb:.1f} GB)...")
        shutil.rmtree(p, ignore_errors=True)
        os.makedirs(p, exist_ok=True)
|
|
|
|
def cleanup_cuda() -> None:
    """Collect Python garbage and return cached CUDA blocks to the driver."""
    gc.collect()
    if not torch.cuda.is_available():
        return
    torch.cuda.empty_cache()
|
|
|
|
| |
| |
| |
|
|
# Lazily-built module-level singleton so every expert shares one tokenizer.
_bert_tokenizer = None


def get_bert_tokenizer():
    """Return the shared BertTokenizer, constructing it on first use.

    The transformers import is deferred so importing this module does not
    pull in (or download) the tokenizer until it is actually needed.
    """
    global _bert_tokenizer
    if _bert_tokenizer is None:
        from transformers import BertTokenizer
        _bert_tokenizer = BertTokenizer.from_pretrained(CFG.bert_model_name)
    return _bert_tokenizer
|
|
|
|
def load_shared_bert():
    """Load the shared BERT-large text encoder onto DEVICE in eval mode.

    Weights are fp16 on CUDA (inference only) and fp32 on CPU.
    """
    from transformers import BertModel
    print("Loading shared BERT-large...")
    dtype = torch.float16 if torch.cuda.is_available() else torch.float32
    model = BertModel.from_pretrained(CFG.bert_model_name, torch_dtype=dtype)
    bert = model.to(DEVICE).eval()
    print(" BERT ready.")
    return bert
|
|
|
|
| |
| |
| |
|
|
def ensure_dir(path: str) -> None:
    """Create *path* (including parents) if it does not already exist."""
    os.makedirs(path, exist_ok=True)
|
|
|
|
def make_loader(ds: Dataset, batch_size: int, num_workers: int) -> DataLoader:
    """Wrap *ds* in a deterministic (unshuffled) DataLoader using CFG knobs.

    prefetch_factor is only legal when num_workers > 0, so it is added
    conditionally; workers are kept alive across epochs in that case.
    """
    extra = {"prefetch_factor": CFG.prefetch_factor} if num_workers > 0 else {}
    return DataLoader(
        dataset=ds,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=CFG.pin_memory,
        persistent_workers=num_workers > 0,
        **extra,
    )
|
|
|
|
def masked_text_tokenize(text: str, tokenizer) -> Tuple[torch.Tensor, torch.Tensor]:
    """Tokenize *text* to exactly CFG.max_text_len tokens.

    Returns (input_ids, attention_mask) as 1-D tensors (batch dim removed).
    """
    encoded = tokenizer(
        text,
        padding="max_length",
        truncation=True,
        max_length=CFG.max_text_len,
        return_tensors="pt",
    )
    ids = encoded["input_ids"].squeeze(0)
    mask = encoded["attention_mask"].squeeze(0)
    return ids, mask
|
|
|
|
def extract_first_text(sample: Dict[str, Any], keys: List[str]) -> str:
    """Return the first non-empty text found in *sample* under *keys*.

    Each key may map to a plain string, or to a list whose first element is a
    string, a dict (checked for "raw" then "text"), or anything str()-able.
    Returns "" when no key yields usable text.
    """
    for key in keys:
        value = sample.get(key)

        if isinstance(value, str):
            stripped = value.strip()
            if stripped:
                return stripped
            continue

        if not (isinstance(value, list) and value):
            continue

        head = value[0]
        if isinstance(head, str):
            head = head.strip()
            if head:
                return head
        if isinstance(head, dict):
            candidate = str(head.get("raw", head.get("text", ""))).strip()
            if candidate:
                return candidate
        # Last resort: the element's string representation.
        fallback = str(head).strip()
        if fallback:
            return fallback

    return ""
|
|
|
|
| |
| |
| |
|
|
class ShardWriter:
    """Buffers rows in memory, spills them to disk as fixed-size HF dataset
    shards, then merges all shards into a single cached dataset on finalize."""

    def __init__(
        self,
        cache_dir: str,
        tag: str,
        features: Features,
        shard_size: int,
        row_keys: List[str],
    ):
        self.cache_dir = cache_dir
        self.tag = tag
        self.features = features
        self.shard_size = shard_size
        self.row_keys = row_keys

        # Final merged location, plus a scratch directory for the shards.
        self.cache_path = os.path.join(cache_dir, tag)
        self.shard_root = os.path.join(cache_dir, f"{tag}__shards")

        # Column-major buffer: one list per row key.
        self.rows = {key: [] for key in row_keys}
        self.shard_paths: List[str] = []
        self.shard_idx = 0
        self.n_written = 0

    @property
    def exists(self) -> bool:
        """True when the merged cache is already present on disk."""
        return os.path.exists(self.cache_path)

    def add_row(self, row: Dict[str, Any]) -> None:
        """Buffer one row; spill a shard once shard_size rows are pending."""
        for key in self.row_keys:
            self.rows[key].append(row[key])
        if len(self.rows[self.row_keys[0]]) >= self.shard_size:
            self.flush()

    def flush(self) -> None:
        """Write pending rows (if any) as one shard and reset the buffer."""
        n_rows = len(self.rows[self.row_keys[0]])
        if not n_rows:
            return

        ensure_dir(self.shard_root)
        shard_path = os.path.join(self.shard_root, f"shard_{self.shard_idx:05d}")
        HFDataset.from_dict(self.rows, features=self.features).save_to_disk(shard_path)

        self.shard_paths.append(shard_path)
        self.shard_idx += 1
        self.n_written += n_rows
        self.rows = {key: [] for key in self.row_keys}

    def finalize(self) -> str:
        """Flush the tail, merge every shard into cache_path, drop the shards."""
        self.flush()

        print(f" Merging {len(self.shard_paths)} shards...")
        shards = [load_from_disk(p) for p in self.shard_paths]
        merged = concatenate_datasets(shards)
        merged.save_to_disk(self.cache_path)
        print(f" Saved {len(merged)} pairs to {self.cache_path}")

        if os.path.exists(self.shard_root):
            shutil.rmtree(self.shard_root, ignore_errors=True)

        return self.cache_path
|
|
|
|
| |
| |
| |
|
|
class ImageTextDataset(Dataset):
    """Map-style dataset pairing captions with preprocessed image tensors.

    Each item is (bert_input_ids, bert_attention_mask, pixel_values, valid).
    Samples with a missing or unreadable image come back as a zero tensor
    with valid=False so the encoding loop can filter them out.
    """

    def __init__(self, hf_ds, bert_tokenizer, image_processor):
        self.ds = hf_ds
        self.tok = bert_tokenizer
        self.proc = image_processor
        # Placeholder pixel-tensor shape for invalid samples.
        self.fallback_shape = (3, 518, 518)

    def __len__(self):
        return len(self.ds)

    def __getitem__(self, idx):
        sample = self.ds[idx]

        caption = extract_first_text(
            sample,
            ["answer", "caption", "captions", "text", "original_alt_text"],
        )
        ids, mask = masked_text_tokenize(caption, self.tok)

        image = sample.get("image", None)

        # Guard: no PIL-like image object at all.
        if image is None or not hasattr(image, "convert"):
            blank = torch.zeros(self.fallback_shape, dtype=torch.float32)
            return ids, mask, blank, False

        try:
            pixels = self.proc(
                images=image.convert("RGB"),
                return_tensors="pt",
            )["pixel_values"].squeeze(0)
        except Exception:
            blank = torch.zeros(self.fallback_shape, dtype=torch.float32)
            return ids, mask, blank, False

        return ids, mask, pixels, True
|
|
|
|
class CodeTextDataset(Dataset):
    """Pairs CodeSearchNet docstrings (BERT tokens) with code (CodeBERT tokens).

    Each item is (bert_ids, bert_mask, stacked [code_ids, code_mask], valid);
    the two code tensors are stacked so the DataLoader returns one tensor.
    """

    def __init__(self, hf_ds, bert_tokenizer, code_tokenizer):
        self.ds = hf_ds
        self.tok = bert_tokenizer
        self.code_tok = code_tokenizer

    def __len__(self):
        return len(self.ds)

    def __getitem__(self, idx):
        sample = self.ds[idx]

        # Prefer the docstring; fall back to the head of the whole function.
        doc = sample.get("func_documentation_string", "")
        if not doc or not doc.strip():
            doc = str(sample.get("whole_func_string", ""))[:200]
        doc = str(doc).strip()[:500]

        ids, mask = masked_text_tokenize(doc, self.tok)

        code = str(sample.get("func_code_string", sample.get("whole_func_string", ""))).strip()[:512]
        valid = len(code) > 5 and len(doc) > 5

        # Zero placeholders; overwritten on a successful tokenize.
        code_ids = torch.zeros(256, dtype=torch.long)
        code_mask = torch.zeros(256, dtype=torch.long)
        if valid:
            try:
                enc = self.code_tok(
                    code,
                    padding="max_length",
                    truncation=True,
                    max_length=256,
                    return_tensors="pt",
                )
                code_ids = enc["input_ids"].squeeze(0)
                code_mask = enc["attention_mask"].squeeze(0)
            except Exception:
                code_ids = torch.zeros(256, dtype=torch.long)
                code_mask = torch.zeros(256, dtype=torch.long)
                valid = False

        return ids, mask, torch.stack([code_ids, code_mask]), valid
|
|
|
|
| |
| |
| |
|
|
@torch.no_grad()
def encode_map_dataset(
    *,
    tag: str,
    loader: DataLoader,
    bert,
    expert_name: str,
    expert_hidden_shape: Tuple[int, int],
    expert_forward: Callable[[torch.Tensor], torch.Tensor],
    shard_size: int,
    max_samples: Optional[int] = None,
) -> str:
    """Encode a map-style DataLoader into a cached (text, expert) pair dataset.

    For each valid sample, both the BERT hidden states of the text and the
    expert encoder's hidden states are stored as float16 arrays.  Rows are
    sharded to disk via ShardWriter and merged under CFG.cache_dir/tag.

    Args:
        tag: Cache subdirectory name; if it already exists the function
            returns its path immediately without re-encoding.
        loader: Yields (text_ids, text_mask, expert_input, valid) batches.
        bert: Shared BERT model already on DEVICE in eval mode.
        expert_name: Column-name prefix for the expert hidden states.
        expert_hidden_shape: (seq_len, hidden_dim) of one expert output.
        expert_forward: Maps a batch of expert inputs to hidden states.
        shard_size: Rows per on-disk shard.
        max_samples: Stop once at least this many samples are encoded.

    Returns:
        Path to the merged dataset on disk.
    """
    cache_path = os.path.join(CFG.cache_dir, tag)
    if os.path.exists(cache_path):
        print(f" Cache exists: {cache_path}")
        return cache_path

    features = Features({
        "text_hidden": Array2D(shape=(CFG.max_text_len, CFG.bert_hidden_dim), dtype="float16"),
        "text_mask": Sequence(Value("bool"), length=CFG.max_text_len),
        f"{expert_name}_hidden": Array2D(shape=expert_hidden_shape, dtype="float16"),
    })

    writer = ShardWriter(
        cache_dir=CFG.cache_dir,
        tag=tag,
        features=features,
        shard_size=shard_size,
        row_keys=["text_hidden", "text_mask", f"{expert_name}_hidden"],
    )

    t0 = time.time()
    n = 0  # samples encoded so far

    for batch in loader:
        text_ids, text_mask, expert_input, valid = batch
        valid_b = valid.bool()

        # Drop samples the Dataset flagged invalid (missing/broken input).
        if not valid_b.any():
            continue

        text_ids = text_ids[valid_b].to(DEVICE, non_blocking=True)
        text_mask_gpu = text_mask[valid_b].to(DEVICE, non_blocking=True)
        expert_input = expert_input[valid_b].to(DEVICE, non_blocking=True)

        # BERT caption encoding; stored fp16 to halve cache size on disk.
        text_hidden = bert(
            input_ids=text_ids,
            attention_mask=text_mask_gpu,
        ).last_hidden_state.detach().to(dtype=torch.float16).cpu().numpy()

        text_mask_np = text_mask_gpu.bool().cpu().numpy()
        expert_hidden = expert_forward(expert_input).detach().to(dtype=torch.float16).cpu().numpy()

        for i in range(text_hidden.shape[0]):
            writer.add_row({
                "text_hidden": text_hidden[i],
                "text_mask": text_mask_np[i].tolist(),
                f"{expert_name}_hidden": expert_hidden[i],
            })

        n += text_hidden.shape[0]
        # Progress log roughly once per 1000 samples: fires whenever the
        # running count crossed a multiple of 1000 within this batch.
        if n % 1000 < CFG.batch_size or n <= CFG.batch_size:
            rate = n / max(time.time() - t0, 1e-6)
            print(f" {n}" + (f"/{max_samples}" if max_samples else "") + f" ({rate:.0f}/s)")

        if max_samples is not None and n >= max_samples:
            break

    final_path = writer.finalize()
    print(f" Completed {n} samples in {time.time() - t0:.0f}s")
    return final_path
|
|
|
|
| |
| |
| |
|
|
def decode_audio_obj(audio_obj) -> Tuple[np.ndarray, int]:
    """Normalize an HF audio field to (waveform, sample_rate).

    Supports decoder objects exposing get_all_samples() as well as legacy
    {"array", "sampling_rate"} dicts; the dict form defaults to 16 kHz when
    no rate is present.  Raises TypeError for anything else.
    """
    if hasattr(audio_obj, "get_all_samples"):
        decoded = audio_obj.get_all_samples()
        return decoded.data.numpy().squeeze(), decoded.sample_rate

    if isinstance(audio_obj, dict):
        sr = audio_obj.get("sampling_rate", 16000)
        return audio_obj["array"], sr

    raise TypeError(f"Unsupported audio object type: {type(audio_obj)}")
|
|
|
|
def stream_librispeech_batches(
    stream,
    bert_tokenizer,
    whisper_processor,
    batch_size: int,
) -> Iterable[Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
    """Yield (text_ids, text_masks, mel_features) batches from a streaming ASR set.

    Samples missing text/audio, failing audio decode, or failing feature
    extraction are silently skipped.  A final partial batch is also yielded.
    """
    ids_buf: List[torch.Tensor] = []
    mask_buf: List[torch.Tensor] = []
    mel_buf: List[torch.Tensor] = []

    for sample in stream:
        text = sample.get("text", sample.get("transcription", ""))
        audio_obj = sample.get("audio")
        if not text or audio_obj is None:
            continue

        try:
            audio_array, sr = decode_audio_obj(audio_obj)
        except Exception:
            continue

        ids, mask = masked_text_tokenize(str(text), bert_tokenizer)

        try:
            mel = whisper_processor(
                audio_array,
                sampling_rate=sr,
                return_tensors="pt",
            ).input_features.squeeze(0)
        except Exception:
            continue

        ids_buf.append(ids)
        mask_buf.append(mask)
        mel_buf.append(mel)

        if len(ids_buf) >= batch_size:
            yield torch.stack(ids_buf), torch.stack(mask_buf), torch.stack(mel_buf)
            ids_buf, mask_buf, mel_buf = [], [], []

    # Flush the trailing partial batch, if any rows remain.
    if ids_buf:
        yield torch.stack(ids_buf), torch.stack(mask_buf), torch.stack(mel_buf)
|
|
|
|
def extract_protein_caption(sample: Dict[str, Any]) -> str:
    """Pull a caption (max 500 chars) from a Protein2Text-style sample.

    Preference order: first non-empty assistant ("gpt") turn, then any turn
    with a non-empty "value", then the raw "protein" field.
    """
    convos = sample.get("conversations", [])
    if isinstance(convos, list):
        # Pass 1: assistant replies only.
        for turn in convos:
            if isinstance(turn, dict) and turn.get("from") == "gpt":
                text = str(turn.get("value", "")).strip()
                if text:
                    return text[:500]
        # Pass 2: any turn that carries a value.
        for turn in convos:
            if isinstance(turn, dict) and "value" in turn:
                text = str(turn["value"]).strip()
                if text:
                    return text[:500]

    return str(sample.get("protein", "")).strip()[:500]
|
|
|
|
def stream_protein_batches(
    stream,
    bert_tokenizer,
    esm_tokenizer,
    batch_size: int,
) -> Iterable[Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]]:
    """Yield (text_ids, text_masks, esm_ids, esm_masks) batches from a protein stream.

    Samples with too-short captions/sequences or failing ESM tokenization are
    skipped; the final partial batch is yielded as well.
    """
    pending: List[Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]] = []

    for sample in stream:
        caption = extract_protein_caption(sample)
        seq = str(sample.get("amino_seq", sample.get("protein_sequence", ""))).strip()

        # Require a minimally meaningful caption AND sequence.
        if len(caption) < 5 or len(seq) < 5:
            continue

        ids, mask = masked_text_tokenize(caption, bert_tokenizer)

        try:
            esm_t = esm_tokenizer(
                seq,
                padding="max_length",
                truncation=True,
                max_length=512,
                return_tensors="pt",
            )
        except Exception:
            continue

        pending.append(
            (ids, mask, esm_t["input_ids"].squeeze(0), esm_t["attention_mask"].squeeze(0))
        )

        if len(pending) >= batch_size:
            yield tuple(torch.stack(column) for column in zip(*pending))
            pending = []

    if pending:
        yield tuple(torch.stack(column) for column in zip(*pending))
|
|
|
|
@torch.no_grad()
def encode_streaming_batches(
    *,
    tag: str,
    expert_name: str,
    expert_hidden_shape: Tuple[int, int],
    batch_iter: Iterable,
    bert,
    expert_batch_forward: Callable[..., torch.Tensor],
    shard_size: int,
    row_keys: List[str],
    max_samples: Optional[int] = None,
) -> str:
    """Encode pre-batched streaming data into a cached (text, expert) dataset.

    Like encode_map_dataset, but consumes an iterator of already-collated
    tuples: (text_ids, text_mask, *expert_inputs).  Everything after the
    first two positions is forwarded to expert_batch_forward.

    Args:
        tag: Cache subdirectory name; an existing cache short-circuits.
        expert_name: Column-name prefix for the expert hidden states.
        expert_hidden_shape: (seq_len, hidden_dim) of one expert output.
        batch_iter: Yields (text_ids, text_mask, *expert_inputs) tuples.
        bert: Shared BERT model on DEVICE in eval mode.
        expert_batch_forward: Maps the expert inputs to hidden states.
        shard_size: Rows per on-disk shard.
        row_keys: Column names for ShardWriter (must match features).
        max_samples: Stop once at least this many samples are encoded.

    Returns:
        Path to the merged dataset on disk.
    """
    cache_path = os.path.join(CFG.cache_dir, tag)
    if os.path.exists(cache_path):
        print(f" Cache exists: {cache_path}")
        return cache_path

    features = Features({
        "text_hidden": Array2D(shape=(CFG.max_text_len, CFG.bert_hidden_dim), dtype="float16"),
        "text_mask": Sequence(Value("bool"), length=CFG.max_text_len),
        f"{expert_name}_hidden": Array2D(shape=expert_hidden_shape, dtype="float16"),
    })

    writer = ShardWriter(
        cache_dir=CFG.cache_dir,
        tag=tag,
        features=features,
        shard_size=shard_size,
        row_keys=row_keys,
    )

    t0 = time.time()
    n = 0  # samples encoded so far

    for packed in batch_iter:
        # packed = (text_ids, text_mask, *expert_inputs)
        text_ids = packed[0].to(DEVICE, non_blocking=True)
        text_mask = packed[1].to(DEVICE, non_blocking=True)

        text_hidden = bert(
            input_ids=text_ids,
            attention_mask=text_mask,
        ).last_hidden_state.detach().to(dtype=torch.float16).cpu().numpy()

        text_mask_np = text_mask.bool().cpu().numpy()

        # Remaining tensors in the tuple are the expert's inputs.
        expert_hidden = expert_batch_forward(*[p.to(DEVICE, non_blocking=True) for p in packed[2:]])
        expert_hidden = expert_hidden.detach().to(dtype=torch.float16).cpu().numpy()

        for i in range(text_hidden.shape[0]):
            writer.add_row({
                "text_hidden": text_hidden[i],
                "text_mask": text_mask_np[i].tolist(),
                f"{expert_name}_hidden": expert_hidden[i],
            })

        n += text_hidden.shape[0]
        batch_size = text_hidden.shape[0]
        # Progress log roughly once per 1000 samples (crossing detection).
        if n % 1000 < batch_size or n <= batch_size:
            rate = n / max(time.time() - t0, 1e-6)
            print(f" {n}" + (f"/{max_samples}" if max_samples else "") + f" ({rate:.0f}/s)")

        if max_samples is not None and n >= max_samples:
            break

    final_path = writer.finalize()
    print(f" Completed {n} samples in {time.time() - t0:.0f}s")
    return final_path
|
|
|
|
| |
| |
| |
|
|
def encode_image_expert(bert, split: str, tag: str, max_samples: Optional[int] = None) -> str:
    """Encode COCO-Caption images (DINOv2-large) + captions (BERT) into a cache.

    Args:
        bert: Shared BERT model on DEVICE.
        split: COCO-Caption split to load (e.g. "val", "test").
        tag: Cache subdirectory name under CFG.cache_dir.
        max_samples: Optional cap applied before encoding.

    Returns:
        Path to the merged cache produced by encode_map_dataset.
    """
    from transformers import Dinov2Model, AutoImageProcessor

    print(f"\n [IMAGE] Loading DINOv2-large + COCO-Caption ({split})...")
    dino = Dinov2Model.from_pretrained(
        "facebook/dinov2-large",
        torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
    ).to(DEVICE).eval()
    proc = AutoImageProcessor.from_pretrained("facebook/dinov2-large")
    tok = get_bert_tokenizer()

    hf_ds = load_dataset("lmms-lab/COCO-Caption", split=split)
    if max_samples is not None:
        hf_ds = hf_ds.select(range(min(max_samples, len(hf_ds))))
    print(f" Dataset: {len(hf_ds)} samples")

    torch_ds = ImageTextDataset(hf_ds, tok, proc)
    loader = make_loader(torch_ds, batch_size=CFG.batch_size, num_workers=CFG.num_workers)

    def expert_forward(pixel_values):
        # Shape per image is (257, 1024) — see expert_hidden_shape below.
        return dino(pixel_values=pixel_values).last_hidden_state

    path = encode_map_dataset(
        tag=tag,
        loader=loader,
        bert=bert,
        expert_name="image",
        expert_hidden_shape=(257, 1024),
        expert_forward=expert_forward,
        shard_size=CFG.shard_size_default,
        max_samples=max_samples,
    )

    # Drop the vision model and dataset references before the next expert.
    del dino, proc, hf_ds, torch_ds, loader
    cleanup_cuda()
    return path
|
|
|
|
def encode_code_expert(bert, max_samples: Optional[int] = None) -> str:
    """Encode CodeSearchNet (python) docstring/code pairs with BERT + CodeBERT.

    Args:
        bert: Shared BERT model on DEVICE.
        max_samples: Optional cap applied before the docstring filter.

    Returns:
        Path to the merged cache produced by encode_map_dataset.
    """
    from transformers import RobertaModel, RobertaTokenizer

    print("\n [CODE] Loading CodeBERT + CodeSearchNet python...")
    codebert = RobertaModel.from_pretrained(
        "microsoft/codebert-base",
        torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
    ).to(DEVICE).eval()
    code_tok = RobertaTokenizer.from_pretrained("microsoft/codebert-base")
    tok = get_bert_tokenizer()

    hf_ds = load_dataset("code-search-net/code_search_net", "python", split="train")
    if max_samples is not None:
        hf_ds = hf_ds.select(range(min(max_samples, len(hf_ds))))

    # Keep only rows with a non-empty docstring.  Fix: `x.get(key, "")` only
    # covers a *missing* key — a present-but-None value would reach .strip()
    # and raise AttributeError, so coalesce None to "" explicitly.
    hf_ds = hf_ds.filter(
        lambda x: bool((x.get("func_documentation_string") or "").strip()),
        num_proc=4,
    )
    print(f" Dataset: {len(hf_ds)} samples (after filtering)")

    torch_ds = CodeTextDataset(hf_ds, tok, code_tok)
    loader = make_loader(torch_ds, batch_size=CFG.batch_size, num_workers=CFG.num_workers)

    def expert_forward(packed):
        # `packed` stacks [code_ids, code_mask] along dim 1 (see CodeTextDataset).
        code_ids = packed[:, 0].long()
        code_mask = packed[:, 1].long()
        return codebert(input_ids=code_ids, attention_mask=code_mask).last_hidden_state

    path = encode_map_dataset(
        tag="code_csn",
        loader=loader,
        bert=bert,
        expert_name="code",
        expert_hidden_shape=(256, 768),
        expert_forward=expert_forward,
        shard_size=CFG.shard_size_default,
        max_samples=max_samples,
    )

    # Free the code encoder and dataset before the next expert loads.
    del codebert, code_tok, hf_ds, torch_ds, loader
    cleanup_cuda()
    return path
|
|
|
|
def encode_audio_expert(bert, max_samples: Optional[int] = None) -> str:
    """Encode LibriSpeech (streaming) transcript/audio pairs.

    Uses the Whisper-large-v3 encoder for audio and the shared BERT for text.
    The encoder output shape is probed with one real sample (rather than
    hard-coded), then the stream is re-opened from the start for encoding.

    Args:
        bert: Shared BERT model on DEVICE.
        max_samples: Cap on encoded samples; defaults to CFG.max_audio_samples.

    Returns:
        Path to the merged cache.
    """
    from transformers import WhisperModel, WhisperFeatureExtractor

    print("\n [AUDIO] Loading Whisper-large-v3 + LibriSpeech ASR (streaming)...")
    whisper = WhisperModel.from_pretrained(
        "openai/whisper-large-v3",
        torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
    ).to(DEVICE).eval()
    proc = WhisperFeatureExtractor.from_pretrained("openai/whisper-large-v3")
    tok = get_bert_tokenizer()

    max_n = max_samples or CFG.max_audio_samples
    audio_dtype = torch.float16 if torch.cuda.is_available() else torch.float32

    # Probe one sample to discover the encoder's (seq_len, hidden_dim).
    probe_stream = load_dataset("openslr/librispeech_asr", "clean", split="train.100", streaming=True)
    probe_stream = probe_stream.cast_column("audio", Audio(sampling_rate=16000))
    first = next(iter(probe_stream))
    arr, sr = decode_audio_obj(first["audio"])

    mel = proc(arr, sampling_rate=sr, return_tensors="pt").input_features
    # Input dtype must match the model's (fp16 on CUDA) weight dtype.
    mel = mel.to(device=DEVICE, dtype=audio_dtype)

    with torch.no_grad():
        probe_hidden = whisper.encoder(mel).last_hidden_state

    seq_len, hidden_dim = probe_hidden.shape[1], probe_hidden.shape[2]
    print(f" Whisper encoder output: ({seq_len}, {hidden_dim})")
    del mel, probe_hidden

    # Re-open the stream from the start for the actual encoding pass.
    stream = load_dataset("openslr/librispeech_asr", "clean", split="train.100", streaming=True)
    stream = stream.cast_column("audio", Audio(sampling_rate=16000))

    batch_iter = stream_librispeech_batches(
        stream=stream,
        bert_tokenizer=tok,
        whisper_processor=proc,
        batch_size=16,
    )

    def expert_batch_forward(mels: torch.Tensor) -> torch.Tensor:
        # Cast to the model's weight dtype before the encoder forward.
        mels = mels.to(dtype=audio_dtype)
        return whisper.encoder(mels).last_hidden_state

    path = encode_streaming_batches(
        tag="audio_librispeech",
        expert_name="audio",
        expert_hidden_shape=(seq_len, hidden_dim),
        batch_iter=batch_iter,
        bert=bert,
        expert_batch_forward=expert_batch_forward,
        shard_size=256,
        row_keys=["text_hidden", "text_mask", "audio_hidden"],
        max_samples=max_n,
    )

    del whisper, proc
    cleanup_cuda()
    return path
|
|
|
|
def encode_protein_expert(bert, max_samples: Optional[int] = None) -> str:
    """Encode Protein2Text-QA (streaming) caption/sequence pairs.

    Uses ESM-2-650M for protein sequences and the shared BERT for captions.

    Args:
        bert: Shared BERT model on DEVICE.
        max_samples: Cap on encoded samples; defaults to CFG.max_protein_samples.

    Returns:
        Path to the merged cache.
    """
    from transformers import EsmModel, EsmTokenizer

    print("\n [PROTEIN] Loading ESM-2-650M + Protein2Text-QA (streaming)...")
    esm = EsmModel.from_pretrained(
        "facebook/esm2_t33_650M_UR50D",
        torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
    ).to(DEVICE).eval()
    esm_tok = EsmTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D")
    tok = get_bert_tokenizer()

    max_n = max_samples or CFG.max_protein_samples
    stream = load_dataset("tumorailab/Protein2Text-QA", split="test", streaming=True)

    batch_iter = stream_protein_batches(
        stream=stream,
        bert_tokenizer=tok,
        esm_tokenizer=esm_tok,
        batch_size=32,
    )

    def expert_batch_forward(esm_ids: torch.Tensor, esm_mask: torch.Tensor) -> torch.Tensor:
        return esm(input_ids=esm_ids.long(), attention_mask=esm_mask.long()).last_hidden_state

    path = encode_streaming_batches(
        tag="protein_p2t",
        expert_name="protein",
        # Matches the max_length=512 ESM tokenization and ESM-2-650M width.
        expert_hidden_shape=(512, 1280),
        batch_iter=batch_iter,
        bert=bert,
        expert_batch_forward=expert_batch_forward,
        shard_size=512,
        row_keys=["text_hidden", "text_mask", "protein_hidden"],
        max_samples=max_n,
    )

    del esm, esm_tok
    cleanup_cuda()
    return path
|
|
|
|
| |
| |
| |
|
|
def main():
    """Run the Stage-1 precompute: encode each expert's cache if missing.

    Each expert step is skipped when its cache directory already exists;
    audio/protein/code are additionally wrapped in try/except so a single
    failing expert does not abort the rest of the run.  The HF download
    cache is optionally wiped between experts to bound disk usage.
    """
    ensure_dir(CFG.cache_dir)

    print("=" * 70)
    print("STAGE 1: MULTI-EXPERT PRECOMPUTE")
    print("=" * 70)

    if torch.cuda.is_available():
        print(f"GPU: {torch.cuda.get_device_name()}")
        print(f"VRAM: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
    print(f"Cache: {CFG.cache_dir}")

    # BERT is only loaded when at least one cache still needs encoding.
    required_tags = [
        "image_coco",
        "image_coco_test",
        "audio_librispeech",
        "protein_p2t",
        "code_csn",
    ]
    missing = [t for t in required_tags if not os.path.exists(os.path.join(CFG.cache_dir, t))]

    if not missing:
        print("\nAll caches exist. Nothing to encode.")
        bert = None
    else:
        print(f"\nMissing caches: {missing}")
        if CFG.cleanup_hf_cache_between_experts:
            cleanup_hf_cache()
        bert = load_shared_bert()

    paths: Dict[str, Optional[str]] = {}

    # --- [1/4] IMAGE (val split) -------------------------------------------
    print(f"\n{'─' * 50}")
    print("[1/4] IMAGE — DINOv2 + COCO-Caption")
    if os.path.exists(os.path.join(CFG.cache_dir, "image_coco")):
        print(" [IMAGE] Cache exists, skipping.")
        paths["image"] = os.path.join(CFG.cache_dir, "image_coco")
    else:
        paths["image"] = encode_image_expert(bert, split="val", tag="image_coco")
        if CFG.cleanup_hf_cache_between_experts:
            cleanup_hf_cache()

    # --- IMAGE (held-out test split) ---------------------------------------
    if os.path.exists(os.path.join(CFG.cache_dir, "image_coco_test")):
        print("\n [IMAGE-TEST] Cache exists, skipping.")
        paths["image_test"] = os.path.join(CFG.cache_dir, "image_coco_test")
    else:
        print("\n [IMAGE-TEST] COCO test split...")
        paths["image_test"] = encode_image_expert(bert, split="test", tag="image_coco_test")
        if CFG.cleanup_hf_cache_between_experts:
            cleanup_hf_cache()

    # --- [2/4] AUDIO (best-effort) -----------------------------------------
    print(f"\n{'─' * 50}")
    print("[2/4] AUDIO — Whisper + LibriSpeech")
    if os.path.exists(os.path.join(CFG.cache_dir, "audio_librispeech")):
        print(" [AUDIO] Cache exists, skipping.")
        paths["audio"] = os.path.join(CFG.cache_dir, "audio_librispeech")
    else:
        try:
            paths["audio"] = encode_audio_expert(bert, max_samples=CFG.max_audio_samples)
        except Exception as e:
            # Best-effort: record the failure and continue with other experts.
            print(f" AUDIO failed: {e}")
            paths["audio"] = None
        if CFG.cleanup_hf_cache_between_experts:
            cleanup_hf_cache()

    # --- [3/4] PROTEIN (best-effort) ---------------------------------------
    print(f"\n{'─' * 50}")
    print("[3/4] PROTEIN — ESM-2 + Protein2Text-QA")
    if os.path.exists(os.path.join(CFG.cache_dir, "protein_p2t")):
        print(" [PROTEIN] Cache exists, skipping.")
        paths["protein"] = os.path.join(CFG.cache_dir, "protein_p2t")
    else:
        try:
            paths["protein"] = encode_protein_expert(bert, max_samples=CFG.max_protein_samples)
        except Exception as e:
            print(f" PROTEIN failed: {e}")
            paths["protein"] = None
        if CFG.cleanup_hf_cache_between_experts:
            cleanup_hf_cache()

    # --- [4/4] CODE (best-effort) ------------------------------------------
    print(f"\n{'─' * 50}")
    print("[4/4] CODE — CodeBERT + CodeSearchNet Python")
    if os.path.exists(os.path.join(CFG.cache_dir, "code_csn")):
        print(" [CODE] Cache exists, skipping.")
        paths["code"] = os.path.join(CFG.cache_dir, "code_csn")
    else:
        try:
            paths["code"] = encode_code_expert(bert, max_samples=CFG.max_code_samples)
        except Exception as e:
            print(f" CODE failed: {e}")
            paths["code"] = None
        if CFG.cleanup_hf_cache_between_experts:
            cleanup_hf_cache()

    # Release the shared text encoder once all experts are done.
    if bert is not None:
        del bert
        cleanup_cuda()

    # A pre-existing Flickr30k cache (built elsewhere) is reported if found.
    flickr_path = os.path.join(CFG.cache_dir, "flickr30k")
    if os.path.exists(flickr_path):
        paths["flickr"] = flickr_path

    print(f"\n{'=' * 70}")
    print("CACHE SUMMARY")
    print(f"{'=' * 70}")
    for name, path in paths.items():
        if path and os.path.exists(path):
            ds = load_from_disk(path)
            print(f" {name:15s}: {len(ds):6d} pairs [{path}]")

    print("\nReady for Stage 2 multi-expert training.")
    print("Done.")
|
|
|
|
# Script entry point: run the full precompute pipeline.
if __name__ == "__main__":
    main()