| """ |
| Collins-RoPE 极简 Embedding 模型(HuggingFace 原生实现) |
| 架构:Hash Embedding (2-Universal + Sign Hash) -> RoPE -> Transformer Encoder -> Mean Pooling |
| 目标参数量:~2M |
| """ |
|
|
import math
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoTokenizer, PretrainedConfig, PreTrainedModel
from transformers.modeling_outputs import BaseModelOutputWithPooling
from transformers.models.bert.modeling_bert import BertConfig, BertEncoder
|
|
|
|
| class CollinsConfig(PretrainedConfig): |
| model_type = "collins" |
|
|
| def __init__( |
| self, |
| vocab_size: int = 30522, |
| num_buckets: int = 2048, |
| hidden_size: int = 256, |
| num_hidden_layers: int = 3, |
| num_attention_heads: int = 8, |
| intermediate_size: int = 1024, |
| hidden_dropout_prob: float = 0.1, |
| attention_probs_dropout_prob: float = 0.1, |
        max_position_embeddings: int = 512,
        hash_seed: int = 42,
| **kwargs, |
| ): |
| super().__init__(**kwargs) |
| self.vocab_size = vocab_size |
| self.num_buckets = num_buckets |
| self.hidden_size = hidden_size |
| self.num_hidden_layers = num_hidden_layers |
| self.num_attention_heads = num_attention_heads |
| self.intermediate_size = intermediate_size |
| self.hidden_dropout_prob = hidden_dropout_prob |
| self.attention_probs_dropout_prob = attention_probs_dropout_prob |
| self.max_position_embeddings = max_position_embeddings |
| self.hash_seed = hash_seed |
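
    # Back-of-envelope parameter count with the defaults above (a rough
    # sketch; biases and LayerNorms add a little on top):
    #   hash table:  2048 * 256                                    ~= 0.52M
    #   per layer:   4 * 256^2 (attention) + 2 * 256 * 1024 (FFN)  ~= 0.79M
    #   total:       0.52M + 3 * 0.79M                             ~= 2.9M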
|
|
|
|
| class CollinsHashEmbedding(nn.Module): |
| """ |
| 2-Universal Hash + Sign Hash 压缩 Embedding。 |
| 哈希参数从 config.hash_seed 确定性生成,保证 save/load 后一致。 |
| """ |
|
|
| def __init__(self, config: CollinsConfig): |
| super().__init__() |
| self.num_buckets = config.num_buckets |
| self.hidden_size = config.hidden_size |
|
|
| self.hash_table = nn.Parameter( |
| torch.randn(config.num_buckets, config.hidden_size) |
| / math.sqrt(config.hidden_size) |
| ) |
|
|
        # Affine hash parameters over Z_p with p = 2^31 - 1 (a Mersenne prime);
        # h(x) = ((a * x + b) mod p) mod m is 2-universal for a in [1, p - 1].
        prime = 2147483647
        rng = torch.Generator()
        rng.manual_seed(config.hash_seed)
        a1 = torch.randint(1, prime, (1,), generator=rng, dtype=torch.long)  # bucket hash
        b1 = torch.randint(0, prime, (1,), generator=rng, dtype=torch.long)
        a2 = torch.randint(1, prime, (1,), generator=rng, dtype=torch.long)  # sign hash
        b2 = torch.randint(0, prime, (1,), generator=rng, dtype=torch.long)
|
|
| self.register_buffer("prime", torch.tensor(prime, dtype=torch.long)) |
| self.register_buffer("a1", a1) |
| self.register_buffer("b1", b1) |
| self.register_buffer("a2", a2) |
| self.register_buffer("b2", b2) |
|
|
    def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
        x = input_ids.long()
        # Bucket hash selects a shared row; sign hash maps to {-1, +1} so
        # tokens that collide receive negated rather than identical vectors,
        # decorrelating collisions in expectation.
        bucket_idx = ((x * self.a1 + self.b1) % self.prime) % self.num_buckets
        sign = ((x * self.a2 + self.b2) % self.prime) % 2
        sign = (sign * 2 - 1).float()
        return self.hash_table[bucket_idx] * sign.unsqueeze(-1)
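
# Determinism sketch (illustrative, not part of the model): two modules built
# from the same config share hash parameters, so token -> bucket/sign
# assignments survive save/load even though hash_table weights differ.
#
#   cfg = CollinsConfig()
#   e1, e2 = CollinsHashEmbedding(cfg), CollinsHashEmbedding(cfg)
#   assert torch.equal(e1.a1, e2.a1) and torch.equal(e1.b2, e2.b2)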
|
|
|
|
| class CollinsModel(PreTrainedModel): |
| """ |
| Collins-RoPE Encoder,输出 last_hidden_state 和 pooler_output。 |
| 使用 transformers.models.bert 的 BertEncoder + RoPE 替换 BertEmbeddings。 |
| """ |
|
|
| config_class = CollinsConfig |
| base_model_prefix = "collins" |
| supports_gradient_checkpointing = True |
|
|
| def __init__(self, config: CollinsConfig): |
| super().__init__(config) |
| self.config = config |
|
|
| self.embeddings = CollinsHashEmbedding(config) |
|
|
| |
| bert_cfg = BertConfig( |
| hidden_size=config.hidden_size, |
| num_hidden_layers=config.num_hidden_layers, |
| num_attention_heads=config.num_attention_heads, |
| intermediate_size=config.intermediate_size, |
| hidden_dropout_prob=config.hidden_dropout_prob, |
| attention_probs_dropout_prob=config.attention_probs_dropout_prob, |
| max_position_embeddings=config.max_position_embeddings, |
| |
            # Positions come from RoPE applied to the inputs in forward(), so
            # the encoder itself stays position-free (no duplicate relative-
            # position terms inside attention).
            position_embedding_type="absolute",
        )
        # Force eager attention so the additive float mask built in forward()
        # is applied as-is.
        bert_cfg._attn_implementation = "eager"
| self.encoder = BertEncoder(bert_cfg) |
|
|
| |
        # Precompute RoPE cos/sin tables over half the hidden dimension.
        # persistent=False: the tables are rebuilt from config at init, so
        # they need not be stored in the state dict.
        dim = config.hidden_size
        inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
        t = torch.arange(config.max_position_embeddings).float()
        freqs = torch.einsum("i,j->ij", t, inv_freq)  # (max_pos, dim/2)
        self.register_buffer("rope_cos", freqs.cos(), persistent=False)
        self.register_buffer("rope_sin", freqs.sin(), persistent=False)
|
|
| self.post_init() |
|
|
    def _apply_rope(self, x: torch.Tensor) -> torch.Tensor:
        # Rotate (even, odd) channel pairs by position-dependent angles.
        # The concat regroups channels as [all even | all odd]; this fixed
        # permutation is harmless since the next layer's weights are learned.
        seq_len = x.shape[1]
        cos = self.rope_cos[:seq_len].unsqueeze(0)  # (1, seq, dim/2)
        sin = self.rope_sin[:seq_len].unsqueeze(0)
        x1, x2 = x[..., 0::2], x[..., 1::2]
        return torch.cat([x1 * cos - x2 * sin, x1 * sin + x2 * cos], dim=-1)
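
    # Note: canonical RoPE rotates queries and keys inside every attention
    # layer; here the rotation is applied once to the input embeddings, a
    # deliberate simplification. Quick property check (illustrative):
    #
    #   m = CollinsModel(CollinsConfig())
    #   v = torch.randn(1, 8, 256)
    #   r = m._apply_rope(v)
    #   assert torch.allclose(r.norm(dim=-1), v.norm(dim=-1), atol=1e-5)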
|
|
    def _extend_attention_mask(self, attention_mask: torch.Tensor) -> torch.Tensor:
        # Broadcast (batch, seq) -> (batch, 1, 1, seq) and convert to an
        # additive mask: 0.0 where attended, a large negative value at padding.
        # Kept private so it does not shadow the base class's
        # get_extended_attention_mask, which has a different signature.
        extended = attention_mask[:, None, None, :]
        extended = (1.0 - extended.float()) * torch.finfo(torch.float32).min
        return extended
|
|
| def forward( |
| self, |
| input_ids: torch.Tensor, |
| attention_mask: Optional[torch.Tensor] = None, |
| return_dict: bool = True, |
| ): |
| if attention_mask is None: |
| attention_mask = torch.ones_like(input_ids) |
|
|
| x = self.embeddings(input_ids) |
| x = self._apply_rope(x) |
|
|
        ext_mask = self._extend_attention_mask(attention_mask)
        encoder_out = self.encoder(x, attention_mask=ext_mask)
        hidden_states = encoder_out.last_hidden_state

        # Masked mean pooling over non-padding positions, then L2-normalize so
        # the sentence embedding is ready for cosine similarity.
        mask = attention_mask.unsqueeze(-1).float()
        pooled = (hidden_states * mask).sum(1) / mask.sum(1).clamp(min=1e-9)
        pooled = F.normalize(pooled, p=2, dim=-1)

        if not return_dict:
            return (hidden_states, pooled)

        return BaseModelOutputWithPooling(
            last_hidden_state=hidden_states,
            pooler_output=pooled,
        )
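
    # Illustrative usage (a sketch, not part of the public API):
    #
    #   model = CollinsModel(CollinsConfig())
    #   out = model(torch.randint(0, 30522, (2, 16)))
    #   out.last_hidden_state.shape  # torch.Size([2, 16, 256])
    #   out.pooler_output.shape      # torch.Size([2, 256]), L2-normalized rows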
|
|
|
|
| class CollinsSTWrapper(nn.Module): |
| """ |
| sentence-transformers 5.x 兼容包装层。 |
| 持有 tokenizer,实现 tokenize() 接口,同时注入 sentence_embedding。 |
| """ |
|
|
    def __init__(
        self,
        collins_model: CollinsModel,
        tokenizer_name_or_path: str = "bert-base-uncased",
        max_seq_length: int = 128,
    ):
        super().__init__()
        self.collins_model = collins_model
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name_or_path)
        self.max_seq_length = max_seq_length
|
|
| def tokenize(self, texts: list[str], padding: str | bool = True) -> dict: |
| return self.tokenizer( |
| texts, |
| padding=padding, |
| truncation=True, |
| max_length=self.max_seq_length, |
| return_tensors="pt", |
| ) |
|
|
    def forward(self, features: dict) -> dict:
        input_ids = features["input_ids"]
        attention_mask = features.get("attention_mask", None)
        out = self.collins_model(input_ids, attention_mask)
        features["sentence_embedding"] = out.pooler_output
        return features
|
|
| def save(self, output_path: str): |
| self.collins_model.save_pretrained(output_path) |
| self.tokenizer.save_pretrained(output_path) |
|
|
| @staticmethod |
| def load(input_path: str) -> "CollinsSTWrapper": |
| model = CollinsModel.from_pretrained(input_path) |
| return CollinsSTWrapper(model, tokenizer_name_or_path=input_path) |
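

if __name__ == "__main__":
    # Smoke test (illustrative). bert-base-uncased is only a convenient
    # tokenizer default; any tokenizer producing integer ids works, since ids
    # are hashed into num_buckets rows regardless of vocabulary size.
    wrapper = CollinsSTWrapper(CollinsModel(CollinsConfig()), max_seq_length=64)
    features = wrapper.tokenize(
        ["hash embeddings are compact", "RoPE injects position information"]
    )
    with torch.no_grad():
        out = wrapper(features)
    emb = out["sentence_embedding"]
    print(emb.shape)         # torch.Size([2, 256])
    print(emb.norm(dim=-1))  # ~1.0 per row (mean-pooled, L2-normalized)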
|
|