"""Entry point for training CodeMode embeddings on (anchor, positive, negative) triplets."""

import argparse
import json
import os

import torch
from torch.utils.data import DataLoader, Dataset
from transformers import AutoTokenizer

from scripts.core.training.model import CodeEmbedder
from scripts.core.training.trainer import CodeTrainer


class RealCodeDataset(Dataset):
    """Triplet dataset read from a JSONL file.

    Each non-empty line must be a JSON object with 'anchor', 'positive',
    and 'negative' code snippets, e.g.:
        {"anchor": "def hello(): ...", "positive": "def hi(): ...", "negative": "class Foo: ..."}
    """

    def __init__(self, jsonl_path, tokenizer, max_length=512):
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.data = []

        print(f"Loading data from {jsonl_path}...")
        with open(jsonl_path, 'r', encoding='utf-8') as f:
            for line in f:
                if line.strip():
                    self.data.append(json.loads(line))
        print(f"Loaded {len(self.data)} triplets.")

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        item = self.data[idx]

        def tokenize_text(text):
            # return_tensors='pt' yields tensors with a leading batch
            # dimension of 1; every snippet is padded/truncated to max_length.
            return self.tokenizer(
                text,
                return_tensors='pt',
                padding='max_length',
                truncation=True,
                max_length=self.max_length
            )

        anchor = tokenize_text(item['anchor'])
        positive = tokenize_text(item['positive'])
        negative = tokenize_text(item['negative'])

        # squeeze(0) drops the batch dimension so the DataLoader's default
        # collate can stack the per-sample tensors into a batch.
        return {
            'anchor_input_ids': anchor['input_ids'].squeeze(0),
            'anchor_attention_mask': anchor['attention_mask'].squeeze(0),
            'positive_input_ids': positive['input_ids'].squeeze(0),
            'positive_attention_mask': positive['attention_mask'].squeeze(0),
            'negative_input_ids': negative['input_ids'].squeeze(0),
            'negative_attention_mask': negative['attention_mask'].squeeze(0)
        }


class DummyCodeDataset(Dataset):
    """Synthetic triplet dataset used for --dry_run pipeline verification."""

    def __init__(self, tokenizer, size=100):
        self.tokenizer = tokenizer
        self.size = size

        # One hand-written triplet repeated `size` times; it only needs to
        # exercise the pipeline, not produce a useful model.
        self.data = [{"anchor": "def hello(): return 'world'", "positive": "def hi(): return 'earth'", "negative": "class Foo: pass"}] * size

    def __len__(self):
        return self.size

    def __getitem__(self, idx):
        item = self.data[idx]

        def tokenize_text(text):
            # Shorter max_length than RealCodeDataset keeps dry runs fast.
            return self.tokenizer(
                text,
                return_tensors='pt',
                padding='max_length',
                truncation=True,
                max_length=128
            )

        anchor = tokenize_text(item['anchor'])
        positive = tokenize_text(item['positive'])
        negative = tokenize_text(item['negative'])

        return {
            'anchor_input_ids': anchor['input_ids'].squeeze(0),
            'anchor_attention_mask': anchor['attention_mask'].squeeze(0),
            'positive_input_ids': positive['input_ids'].squeeze(0),
            'positive_attention_mask': positive['attention_mask'].squeeze(0),
            'negative_input_ids': negative['input_ids'].squeeze(0),
            'negative_attention_mask': negative['attention_mask'].squeeze(0)
        }


def main():
    parser = argparse.ArgumentParser(description="Train CodeMode Embeddings")

    parser.add_argument("--model_name", type=str, default="microsoft/codebert-base", help="Hub model name")
    parser.add_argument("--data_path", type=str, required=False, help="Path to parsed chunks.jsonl")
    parser.add_argument("--output_dir", type=str, default="./output", help="Where to save checkpoints")
    parser.add_argument("--epochs", type=int, default=3)
    parser.add_argument("--batch_size", type=int, default=8)
    parser.add_argument("--accumulation_steps", type=int, default=4, help="Gradient accumulation steps")
    parser.add_argument("--lr", type=float, default=2e-5)
    parser.add_argument("--dry_run", action="store_true", help="Run with dummy data for 1 epoch")

    args = parser.parse_args()

    print("Initializing Training Pipeline...")
    print(f" Model: {args.model_name}")
    print(f" Output: {args.output_dir}")
    print(f" Device: {'cuda' if torch.cuda.is_available() else 'cpu'}")

    tokenizer = AutoTokenizer.from_pretrained(args.model_name)

    # --dry_run forces the dummy dataset and a single verification epoch.
    if args.dry_run:
        print("Dry run: using DUMMY data for 1 epoch.")
        train_dataset = DummyCodeDataset(tokenizer, size=100)
        args.epochs = 1
    elif args.data_path and os.path.exists(args.data_path):
        train_dataset = RealCodeDataset(args.data_path, tokenizer)
    else:
        print("No data path provided or file missing. Using DUMMY data for verification.")
        train_dataset = DummyCodeDataset(tokenizer, size=100)

    train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True)

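    # Each batch from train_loader is a dict of [batch_size, seq_len] tensors
    # keyed by {anchor,positive,negative}_{input_ids,attention_mask}; the
    # trainer is assumed to embed each view and apply a triplet-style loss.
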
    model = CodeEmbedder(model_name_or_path=args.model_name)

    trainer = CodeTrainer(
        model=model,
        train_loader=train_loader,
        epochs=args.epochs,
        learning_rate=args.lr,
        accumulation_steps=args.accumulation_steps,
        mixed_precision=True,
        output_dir=args.output_dir
    )

    trainer.train()

    print("Training Complete.")


if __name__ == "__main__":
    main()
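
# Example invocations (the module path below is an assumption; adjust to
# wherever this script lives in the repo):
#   python -m scripts.core.training.train --dry_run
#   python -m scripts.core.training.train --data_path data/chunks.jsonl \
#       --epochs 3 --batch_size 8 --accumulation_steps 4 --lr 2e-5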