|
|
|
|
|
|
| import torch
|
| from transformers import T5Tokenizer
|
| from sentence_transformers import SentenceTransformer
|
| import torch.nn as nn
|
|
|
|
|
# --- Inference configuration -------------------------------------------------

# Checkpoint produced by the semantic-mapper training run.
MAPPER_PTH = "semantic_mapper.pth"

# Checkpoint for the GRU decoder that turns embeddings back into token ids.
DECODER_PTH = "embedding_decoder.pth"

# Sentence-embedding model used to encode the user's message.
MODEL_NAME = "Snowflake/snowflake-arctic-embed-l-v2.0"

# Hard cap on the number of tokens the decoder may generate per reply.
MAX_LEN = 4096

# Prefer GPU when available; all modules and tensors are moved here.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# T5 tokenizer supplies the vocabulary for decoding generated ids back to text.
# NOTE(review): assumes the decoder checkpoint was trained against t5-small's
# vocabulary — confirm vocab_size in the checkpoint matches.
tokenizer = T5Tokenizer.from_pretrained("t5-small")

# Special token ids: pad doubles as the decoder's start token (T5 convention);
# eos terminates generation.
pad_id = tokenizer.pad_token_id
eos_id = tokenizer.eos_token_id
|
|
|
|
|
class SemanticMapper(nn.Module):
    """Two-layer MLP mapping one embedding space onto another of equal width.

    The input is expanded to twice its width, passed through a ReLU, and
    projected back down, so input and output shapes are identical.
    """

    def __init__(self, dim):
        """Build the mapper for *dim*-dimensional embeddings."""
        super().__init__()
        hidden = dim * 2
        # Keep the attribute name `net` — state_dict keys depend on it.
        self.net = nn.Sequential(
            nn.Linear(dim, hidden),
            nn.ReLU(),
            nn.Linear(hidden, dim),
        )

    def forward(self, x):
        """Map a (batch, dim) tensor to a (batch, dim) tensor."""
        return self.net(x)
|
|
|
class EmbeddingDecoder(nn.Module):
    """GRU decoder that reconstructs a token sequence from a sentence embedding.

    The embedding conditions generation in two ways: it is bridged into the
    GRU's initial hidden state, and it is concatenated onto the input of every
    decoding step.  The output projection shares weights with the token
    embedding table (weight tying).
    """

    def __init__(self, input_dim, hidden_dim, vocab_size, p=0.2):
        """
        Args:
            input_dim: width of the sentence-embedding vector.
            hidden_dim: GRU hidden size; must equal the embedding-table width
                for the weight tying below to be valid.
            vocab_size: number of token ids the decoder can emit.
            p: dropout probability (identity in eval mode).
        """
        super().__init__()
        self.bridge = nn.Linear(input_dim, hidden_dim)
        self.embed = nn.Embedding(vocab_size, hidden_dim)
        self.gru = nn.GRU(hidden_dim + input_dim, hidden_dim, batch_first=True)
        self.ln = nn.LayerNorm(hidden_dim)
        self.fc = nn.Linear(hidden_dim, vocab_size, bias=True)
        self.drop = nn.Dropout(p)

        # Weight tying: both matrices are (vocab_size, hidden_dim).
        self.fc.weight = self.embed.weight

    @torch.no_grad()
    def greedy_decode(self, emb_vec, max_len, start_id, eos_id, pad_token_id=None):
        """Greedily decode token ids from a batch of embeddings.

        Args:
            emb_vec: (B, input_dim) embedding batch.
            max_len: maximum number of generation steps.
            start_id: token id fed as the first decoder input.
            eos_id: decoding stops once every row has emitted this id.
            pad_token_id: token id masked out of the logits so it is never
                emitted.  Defaults to the module-level ``pad_id`` for
                backward compatibility with existing callers; pass it
                explicitly to use this class outside that module context.

        Returns:
            (B, T) LongTensor of generated ids with T <= max_len.  Rows that
            emit eos early keep stepping until all rows have; with B > 1,
            trim each row after its first eos.
        """
        if pad_token_id is None:
            # Fall back to the module global — previously this was a hidden,
            # hard dependency; now it is overridable per call.
            pad_token_id = pad_id
        B, _ = emb_vec.shape
        # Initial hidden state derived from the embedding: (1, B, hidden_dim).
        h = torch.tanh(self.bridge(emb_vec)).unsqueeze(0)
        inp = torch.full((B, 1), start_id, dtype=torch.long, device=emb_vec.device)
        out_ids = []
        for _ in range(max_len):
            # Dropout is the identity when the module is in eval() mode.
            token_h = self.drop(self.embed(inp))
            # Re-inject the sentence embedding at every step.
            step_in = torch.cat([token_h, emb_vec.unsqueeze(1)], dim=-1)
            out, h = self.gru(step_in, h)
            out = self.ln(out.squeeze(1))
            logits = self.fc(self.drop(out))
            # Never emit the padding token.
            logits[:, pad_token_id] = -1e9
            next_id = torch.argmax(logits, dim=-1)
            out_ids.append(next_id.unsqueeze(1))
            if (next_id == eos_id).all():
                break
            inp = next_id.unsqueeze(1)
        return torch.cat(out_ids, dim=1)
|
|
|
|
|
|
|
# --- Load trained weights ----------------------------------------------------
# NOTE(review): torch.load unpickles arbitrary objects — only load checkpoints
# from trusted sources (consider weights_only=True on torch >= 1.13).

# Checkpoint layout assumed: {"dim": int, "state_dict": ...} — TODO confirm
# against the training script that produced these files.
mapper_ckpt = torch.load(MAPPER_PTH, map_location=DEVICE)
mapper = SemanticMapper(mapper_ckpt["dim"]).to(DEVICE)
mapper.load_state_dict(mapper_ckpt["state_dict"])
mapper.eval()  # inference mode: disables dropout

# Decoder checkpoint additionally records the vocab size; the hidden size 512
# is hard-coded and must match what was used at training time.
dec_ckpt = torch.load(DECODER_PTH, map_location=DEVICE)
decoder = EmbeddingDecoder(dec_ckpt["dim"], 512, dec_ckpt["vocab_size"]).to(DEVICE)
decoder.load_state_dict(dec_ckpt["state_dict"])
decoder.eval()

# Sentence embedder used to encode user input at chat time.
embedder = SentenceTransformer(MODEL_NAME, device=DEVICE)
|
|
|
|
|
def chat():
    """Interactive loop: embed the user's text, map it, decode a reply.

    Reads from stdin until an empty line, 'quit', or 'exit' is entered.
    Uses the module-level tokenizer, embedder, mapper, and decoder.
    """
    print("Chat ready. Type 'quit' to exit.")
    while True:
        user = input("User: ").strip()
        if not user or user.lower() in {"quit", "exit"}:
            break
        # Inference only — run without autograd so no graph is built per turn
        # (the original tracked gradients through the mapper every message).
        with torch.no_grad():
            x = embedder.encode([user], convert_to_tensor=True, device=DEVICE)
            y_pred = mapper(x)
            # pad_id as start token follows the T5 decoder convention.
            ids = decoder.greedy_decode(
                y_pred, max_len=MAX_LEN, start_id=pad_id, eos_id=eos_id
            )[0].tolist()
        reply = tokenizer.decode(ids, skip_special_tokens=True)
        print("Bot:", reply)


if __name__ == "__main__":
    chat()
|
|
|