import torch
import torch.nn as nn
import torch.optim as optim
import pandas as pd
from transformers import T5Tokenizer
from sentence_transformers import SentenceTransformer

INPUT_FILE = "chat_1turn.csv"
EMB_FILE = "chat_embeddings.pt"
MODEL_NAME = "Snowflake/snowflake-arctic-embed-l-v2.0"
EPOCHS = 80
BATCH_SIZE = 16
HIDDEN_DIM = 512
MAX_LEN = 64

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

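# Assumed layout of chat_1turn.csv (the file itself is not shown here):
# one single-turn exchange per row, e.g.
#   source,target
#   "hi there","Hello! How can I help you?"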
# Load source/target chat pairs.
df = pd.read_csv(INPUT_FILE)
sources = df["source"].fillna("").tolist()
targets = df["target"].fillna("").tolist()

# Tokenize targets once; these token ids are the decoder's training labels.
tokenizer = T5Tokenizer.from_pretrained("t5-small")
target_enc = tokenizer(targets, padding=True, truncation=True,
                       return_tensors="pt", max_length=MAX_LEN)
input_ids = target_enc["input_ids"].to(device)
# attention_mask is kept for completeness; the GRU decoder below never uses it.
attention_mask = target_enc["attention_mask"].to(device)

# Load precomputed sentence embeddings for both sides of each pair.
# map_location keeps this loadable on CPU-only machines.
emb_data = torch.load(EMB_FILE, map_location=device)
x_embeddings = emb_data["source"].to(device)
y_embeddings = emb_data["target"].to(device)

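# NOTE: this script assumes chat_embeddings.pt already exists. A minimal
# sketch of how it could be produced with the same embedder (an assumption;
# the generation step is not shown in this script):
#
#   embedder = SentenceTransformer(MODEL_NAME, device=device)
#   torch.save({"source": embedder.encode(sources, convert_to_tensor=True),
#               "target": embedder.encode(targets, convert_to_tensor=True)},
#              EMB_FILE)
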
class EmbeddingDecoder(nn.Module):
    """GRU decoder that reconstructs token ids from a sentence embedding."""

    def __init__(self, input_dim, hidden_dim, vocab_size):
        super().__init__()
        self.bridge = nn.Linear(input_dim, hidden_dim)
        self.embed = nn.Embedding(vocab_size, hidden_dim)
        self.gru = nn.GRU(hidden_dim, hidden_dim, batch_first=True)
        self.fc = nn.Linear(hidden_dim, vocab_size)

    def forward(self, emb_vec, target_ids=None, teacher_forcing_ratio=0.5, max_len=MAX_LEN):
        # Project the sentence embedding into the GRU's initial hidden state.
        hidden = self.bridge(emb_vec).unsqueeze(0)
        B = emb_vec.size(0)
        outputs = []

        # T5 uses the pad token as the decoder start token.
        inp = torch.full((B, 1), tokenizer.pad_token_id, device=emb_vec.device)

        for t in range(max_len):
            inp_emb = self.embed(inp)
            out, hidden = self.gru(inp_emb, hidden)
            logits = self.fc(out.squeeze(1))
            outputs.append(logits.unsqueeze(1))

            # Teacher forcing: with probability teacher_forcing_ratio feed the
            # ground-truth token; otherwise feed the model's greedy prediction.
            if (target_ids is not None and t < target_ids.size(1)
                    and torch.rand(1).item() < teacher_forcing_ratio):
                inp = target_ids[:, t].unsqueeze(1)
            else:
                inp = torch.argmax(logits, dim=-1, keepdim=True)

        return torch.cat(outputs, dim=1)

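# Quick shape check (a sketch; assumes the arctic-embed-l-v2.0 vectors are
# 1024-d, which y_embeddings.shape[1] confirms at runtime):
#
#   _dec = EmbeddingDecoder(1024, HIDDEN_DIM, tokenizer.vocab_size)
#   _out = _dec(torch.randn(2, 1024), max_len=5)
#   assert _out.shape == (2, 5, tokenizer.vocab_size)
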
decoder = EmbeddingDecoder(y_embeddings.shape[1], HIDDEN_DIM, tokenizer.vocab_size).to(device)
optimizer = optim.Adam(decoder.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss(ignore_index=tokenizer.pad_token_id)

| print("Training decoder...")
|
| for epoch in range(EPOCHS):
|
| decoder.train()
|
| total_loss = 0.0
|
| for i in range(0, len(y_embeddings), BATCH_SIZE):
|
| xb = y_embeddings[i:i+BATCH_SIZE]
|
| yb = input_ids[i:i+BATCH_SIZE]
|
|
|
| optimizer.zero_grad()
|
| logits = decoder(xb, target_ids=yb, teacher_forcing_ratio=0.7, max_len=yb.size(1))
|
| loss = criterion(logits.reshape(-1, logits.size(-1)), yb.reshape(-1))
|
| loss.backward()
|
| optimizer.step()
|
| total_loss += loss.item()
|
|
|
| print(f"Epoch {epoch+1}/{EPOCHS} - Loss: {total_loss:.4f}")
|
|
|
|
|
embedder = SentenceTransformer(MODEL_NAME, device=device)

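# generate() below accepts an optional source-to-target embedding mapper,
# which this script never defines. A minimal sketch of one (a hypothetical,
# untrained MLP that could be fit on (x_embeddings, y_embeddings) pairs):
#
#   mapper = nn.Sequential(
#       nn.Linear(x_embeddings.shape[1], 1024),
#       nn.ReLU(),
#       nn.Linear(1024, y_embeddings.shape[1]),
#   ).to(device)
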
def generate(text, max_len=30, use_mapper=False, mapper=None):
    decoder.eval()
    with torch.no_grad():
        emb = embedder.encode([text], convert_to_tensor=True, device=device)
        # Without a mapper, a *source* embedding is fed to a decoder trained
        # on *target* embeddings, so outputs may be unreliable.
        if use_mapper and mapper is not None:
            emb = mapper(emb)
        logits = decoder(emb, target_ids=None, teacher_forcing_ratio=0.0, max_len=max_len)
        ids = torch.argmax(logits, dim=-1).squeeze(0).tolist()
        return tokenizer.decode(ids, skip_special_tokens=True)

| print("Hi ->", generate("Hi"))
|
|
|
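# With a trained mapper (see the sketch above), the mapped path would be:
#   print("Hi ->", generate("Hi", use_mapper=True, mapper=mapper))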