Add tailored from-scratch training stack (tokenizer + random-init LM) for math conjecture solving
e69a71a verified | #!/usr/bin/env python3 | |
| """Train a math-conjecture model from scratch (tokenizer + random-init LM).""" | |
| from __future__ import annotations | |
| import argparse | |
| import json | |
| import os | |
| from pathlib import Path | |
| from typing import Any, Dict, Iterable, Optional, Tuple | |
| import torch | |
| import yaml | |
| from datasets import Dataset, DatasetDict, load_dataset | |
| from huggingface_hub import HfApi | |
| from tokenizers import Tokenizer, decoders, models, normalizers, pre_tokenizers, processors, trainers | |
| from transformers import ( | |
| DataCollatorForSeq2Seq, | |
| GPT2Config, | |
| GPT2LMHeadModel, | |
| PreTrainedTokenizerFast, | |
| Trainer, | |
| TrainingArguments, | |
| set_seed, | |
| ) | |
# Default for --config. Relative, so it is resolved against the current working directory.
DEFAULT_CONFIG_PATH = Path("model_development") / "configs" / "math_conjecture_scratch.yaml"
def parse_args() -> argparse.Namespace:
    """Parse the command line for the scratch-training entry point.

    Apart from ``--config`` (which defaults to ``DEFAULT_CONFIG_PATH``), every
    option defaults to ``None`` or ``False``. This lets ``apply_overrides``
    distinguish "not supplied" from an explicit value.
    """
    cli = argparse.ArgumentParser(
        description="Build tokenizer and train a random-init math-conjecture solver model from scratch."
    )
    cli.add_argument("--config", type=Path, default=DEFAULT_CONFIG_PATH, help="YAML config path.")

    # (flag, value type, help) for overrides that take a value; all default to None.
    value_options = (
        ("--output-root", Path, "Override global.output_root."),
        ("--repo-id", str, "Override hub.repo_id."),
        ("--max-train-samples", int, "Optional train subset."),
        ("--max-eval-samples", int, "Optional eval subset."),
        ("--tokenizer-max-rows", int, "Override tokenizer.max_train_rows."),
    )
    for flag, value_type, help_text in value_options:
        cli.add_argument(flag, type=value_type, default=None, help=help_text)

    # Boolean switches (store_true, i.e. default False).
    switch_options = (
        ("--init-only", "Only build tokenizer/model and save artifacts."),
        ("--dry-run", "Validate pipeline without running training."),
        ("--push-to-hub", "Force Hub push enabled."),
        ("--no-push-to-hub", "Force Hub push disabled."),
    )
    for flag, help_text in switch_options:
        cli.add_argument(flag, action="store_true", help=help_text)

    cli.add_argument("--credentials-path", type=Path, default=None, help="Override credentials.path.")
    return cli.parse_args()
def as_text(value: Any) -> str:
    """Return *value* as a whitespace-stripped string; ``None`` maps to ``""``."""
    if value is None:
        return ""
    text = value if isinstance(value, str) else str(value)
    return text.strip()
def as_int(value: Any, default: int) -> int:
    """Convert *value* with ``int()``, returning *default* for ``None`` or unconvertible input.

    Only ``TypeError`` and ``ValueError`` are treated as "unconvertible"; any other
    exception raised by ``int()`` propagates.
    """
    if value is None:
        return default
    try:
        converted = int(value)
    except (TypeError, ValueError):
        converted = default
    return converted
def as_float(value: Any, default: float) -> float:
    """Convert *value* with ``float()``, returning *default* for ``None`` or unconvertible input.

    Only ``TypeError`` and ``ValueError`` are treated as "unconvertible"; any other
    exception raised by ``float()`` propagates.
    """
    if value is None:
        return default
    try:
        converted = float(value)
    except (TypeError, ValueError):
        converted = default
    return converted
def load_config(path: Path) -> Dict[str, Any]:
    """Load the YAML training config at *path* and validate its top-level shape.

    The ``global``, ``tokenizer``, ``model``, ``data`` and ``training`` sections
    must be present and be mappings. The optional ``hub`` and ``credentials``
    sections are guaranteed to be dicts on return. Missing or empty (``key:``)
    optional sections become ``{}``.

    Raises:
        FileNotFoundError: *path* does not exist.
        ValueError: the document is not a mapping, a required section is missing
            or not a mapping, or an optional section is present but not a mapping.
    """
    if not path.exists():
        raise FileNotFoundError(f"Config not found: {path}")
    cfg = yaml.safe_load(path.read_text(encoding="utf-8"))
    if not isinstance(cfg, dict):
        raise ValueError(f"Invalid config format: {path}")
    for key in ("global", "tokenizer", "model", "data", "training"):
        if key not in cfg or not isinstance(cfg[key], dict):
            raise ValueError(f"Config missing section: {key}")
    # An empty YAML key (e.g. "hub:") parses to None. setdefault() would keep that None,
    # and later ``.get(...)`` calls on the section would then fail, so normalise it here.
    for key in ("hub", "credentials"):
        section = cfg.get(key)
        if section is None:
            cfg[key] = {}
        elif not isinstance(section, dict):
            raise ValueError(f"Config section must be a mapping: {key}")
    return cfg
def apply_overrides(cfg: Dict[str, Any], args: argparse.Namespace) -> None:
    """Write command-line overrides from *args* into *cfg* in place.

    Only options that were actually supplied are applied: ``None`` values and an
    empty ``--repo-id`` are ignored. Path-valued options are stored as strings.
    ``--push-to-hub`` and ``--no-push-to-hub`` set ``hub.push_to_hub`` to
    True or False respectively.

    Raises:
        ValueError: both ``--push-to-hub`` and ``--no-push-to-hub`` were given.
            In that case *cfg* is left unmodified.
    """
    # Validate before mutating so a rejected invocation does not leave cfg half-updated.
    if args.push_to_hub and args.no_push_to_hub:
        raise ValueError("Cannot set both --push-to-hub and --no-push-to-hub.")

    def _optional_section(name: str) -> Dict[str, Any]:
        # Optional sections may be absent, or parsed from an empty YAML key as None.
        section = cfg.get(name)
        if section is None:
            section = {}
            cfg[name] = section
        return section

    if args.output_root is not None:
        cfg["global"]["output_root"] = str(args.output_root)
    if args.max_train_samples is not None:
        cfg["data"]["max_train_samples"] = args.max_train_samples
    if args.max_eval_samples is not None:
        cfg["data"]["max_eval_samples"] = args.max_eval_samples
    if args.tokenizer_max_rows is not None:
        cfg["tokenizer"]["max_train_rows"] = args.tokenizer_max_rows
    if args.repo_id:
        _optional_section("hub")["repo_id"] = args.repo_id
    if args.credentials_path is not None:
        _optional_section("credentials")["path"] = str(args.credentials_path)
    if args.push_to_hub:
        _optional_section("hub")["push_to_hub"] = True
    if args.no_push_to_hub:
        _optional_section("hub")["push_to_hub"] = False
def resolve_auth(cfg: Dict[str, Any]) -> Tuple[Optional[str], Optional[str]]:
    """Resolve the Hugging Face Hub token and username.

    Environment variables take precedence: ``HF_TOKEN`` (then
    ``HUGGINGFACE_HUB_TOKEN``) for the token, and ``HF_USERNAME`` for the username.
    Any value still missing is read from the JSON object at ``credentials.path``
    (keys ``"key"`` and ``"username"``) when that file exists. ``~`` in the path
    is expanded.

    Returns:
        ``(token, username)``; either may be ``None``.

    Raises:
        ValueError: the credentials file is not valid JSON or is not a JSON object.
    """
    token = as_text(os.environ.get("HF_TOKEN") or os.environ.get("HUGGINGFACE_HUB_TOKEN")) or None
    username = as_text(os.environ.get("HF_USERNAME")) or None
    if token is not None and username is not None:
        # Nothing left to look up; don't touch the credentials file.
        return token, username

    # ``credentials`` may be missing or an empty YAML key (None).
    credentials = cfg.get("credentials") or {}
    cred_path = as_text(credentials.get("path"))
    if not cred_path:
        return token, username
    path = Path(cred_path).expanduser()
    if not path.is_file():
        return token, username

    try:
        data = json.loads(path.read_text(encoding="utf-8"))
    except json.JSONDecodeError as err:
        raise ValueError(f"Invalid credentials JSON: {path}") from err
    if not isinstance(data, dict):
        raise ValueError(f"Credentials file must contain a JSON object: {path}")

    if token is None:
        token = as_text(data.get("key")) or None
    if username is None:
        username = as_text(data.get("username")) or None
    return token, username
def load_raw_datasets(data_cfg: Dict[str, Any]) -> DatasetDict:
    """Load the train and validation parquet splits configured in *data_cfg*.

    Reads the paths from ``data.train_file`` and ``data.validation_file``.
    Each split is loaded with ``load_dataset("parquet", ...)``, and the result
    is returned as a ``DatasetDict`` with ``"train"`` and ``"validation"`` keys.

    Raises:
        ValueError: a split path is not configured (missing or blank).
        FileNotFoundError: a configured split path does not exist.
    """
    files: Dict[str, Path] = {}
    for split_name, cfg_key in (("train", "train_file"), ("validation", "validation_file")):
        raw_path = as_text(data_cfg.get(cfg_key))
        if not raw_path:
            # Path("") is Path("."), which always exists, so an unset key would
            # otherwise slip past the existence check below.
            raise ValueError(f"data.{cfg_key} is required.")
        split_path = Path(raw_path)
        if not split_path.exists():
            raise FileNotFoundError(f"Missing {split_name} split: {split_path}")
        files[split_name] = split_path

    splits: Dict[str, Dataset] = {}
    for split_name, split_path in files.items():
        loaded = load_dataset("parquet", data_files={split_name: str(split_path)})
        # The result is keyed by the data_files key; fall back to its first split defensively.
        splits[split_name] = loaded[split_name] if split_name in loaded else next(iter(loaded.values()))
    return DatasetDict(splits)
def maybe_select(dataset: Dataset, max_samples: Optional[int]) -> Dataset:
    """Return at most the first *max_samples* rows of *dataset*.

    ``None`` means "no limit". A limit at or above ``len(dataset)`` returns
    *dataset* itself. Otherwise the result is ``dataset.select(range(max_samples))``.

    Raises:
        ValueError: *max_samples* is zero or negative.
    """
    if max_samples is not None:
        if max_samples <= 0:
            raise ValueError("max_samples must be positive.")
        if max_samples < len(dataset):
            return dataset.select(range(max_samples))
    return dataset
def stringify_structured(value: Any) -> str:
    """Render a possibly-structured field as canonical text.

    - ``None`` and blank strings become ``""``.
    - A string holding valid JSON is re-serialised with sorted keys.
    - Any other string is returned stripped.
    - Non-string values are serialised with ``json.dumps`` (sorted keys, non-ASCII kept).
    """
    if value is None:
        return ""
    if not isinstance(value, str):
        return json.dumps(value, ensure_ascii=False, sort_keys=True)
    stripped = value.strip()
    if stripped == "":
        return ""
    try:
        decoded = json.loads(stripped)
    except json.JSONDecodeError:
        return stripped
    return json.dumps(decoded, ensure_ascii=False, sort_keys=True)
def build_user_block(row: Dict[str, Any], data_cfg: Dict[str, Any]) -> str:
    """Compose the user turn from the prompt column and selected metadata columns.

    The prompt column name comes from ``data_cfg["prompt_field"]`` (default
    ``"prompt"``). A blank prompt is replaced with ``"Solve the math task."``.
    Each non-empty metadata column is appended as a ``Label: value`` line under a
    ``Metadata:`` heading. Without metadata, the prompt is returned on its own.
    """
    column = as_text(data_cfg.get("prompt_field")) or "prompt"
    prompt = as_text(row.get(column)) or "Solve the math task."

    labelled_columns = (
        ("task_type", "Task type"),
        ("family", "Family"),
        ("difficulty", "Difficulty"),
        ("source_dataset", "Source"),
        ("status_as_of", "Status as of"),
    )
    present = ((label, as_text(row.get(key))) for key, label in labelled_columns)
    meta_lines = [f"{label}: {value}" for label, value in present if value]
    if meta_lines:
        return prompt + "\n\nMetadata:\n" + "\n".join(meta_lines)
    return prompt
def build_answer_block(row: Dict[str, Any], data_cfg: Dict[str, Any]) -> str:
    """Compose the assistant's supervised answer text for *row*.

    Up to three labelled sections are emitted, in this order and separated by
    blank lines. Each is included only when its column renders non-empty through
    ``stringify_structured``:
    - the structured target (``target_field``, default ``"target"``);
    - the final answer (``final_answer_field``, default ``"final_answer"``);
    - the formal proof (``proof_field``, default ``"proof_formal"``).

    If every section is empty, a fixed placeholder sentence is returned.
    """
    labelled_fields = (
        ("Structured target", as_text(data_cfg.get("target_field")) or "target"),
        ("Final answer", as_text(data_cfg.get("final_answer_field")) or "final_answer"),
        ("Formal proof snippet", as_text(data_cfg.get("proof_field")) or "proof_formal"),
    )
    sections = []
    for heading, column in labelled_fields:
        rendered = stringify_structured(row.get(column))
        if rendered:
            sections.append(f"{heading}:\n{rendered}")
    if not sections:
        return "No structured target provided."
    return "\n\n".join(sections).strip()
def build_prompt_text(row: Dict[str, Any], data_cfg: Dict[str, Any]) -> str:
    """Render the chat-formatted prompt, ending with the open assistant tag.

    Layout: ``<|system|>``, ``<|user|>`` and ``<|assistant|>`` tags, each on its
    own line. The system prompt comes from ``data_cfg["system_prompt"]`` (with a
    built-in default), and the user turn comes from ``build_user_block``. The
    result ends with ``"<|assistant|>\\n"`` so the answer text can be appended
    directly.
    """
    system_prompt = as_text(data_cfg.get("system_prompt")) or (
        "You are NorthernTribe Research's math-conjecture solver. "
        "Produce rigorous, checkable reasoning."
    )
    turns = (
        ("<|system|>", system_prompt),
        ("<|user|>", build_user_block(row, data_cfg)),
    )
    rendered_turns = "".join(f"{tag}\n{body}\n" for tag, body in turns)
    return rendered_turns + "<|assistant|>\n"
def iter_tokenizer_corpus(dataset: Dataset, data_cfg: Dict[str, Any], max_rows: int) -> Iterable[str]:
    """Lazily yield one tokenizer-training document per row.

    Each document is the rendered prompt immediately followed by the answer
    block. Only the first ``min(max_rows, len(dataset))`` rows are visited, in
    index order.
    """
    row_limit = min(max_rows, len(dataset))
    for position in range(row_limit):
        record = dataset[position]
        yield build_prompt_text(record, data_cfg) + build_answer_block(record, data_cfg)
def build_tokenizer(
    train_dataset: Dataset,
    tok_cfg: Dict[str, Any],
    data_cfg: Dict[str, Any],
    output_root: Path,
) -> PreTrainedTokenizerFast:
    """Train a byte-level BPE tokenizer on the training split and save it.

    Keys read from *tok_cfg*:
    - ``vocab_size``: default 32000, floor 2048.
    - ``min_frequency``: default 2, floor 1.
    - ``max_train_rows``: default all rows, floor 100.
    - ``special_tokens``: optional list. The built-in pad/unk/bos/eos and chat
      tags are always appended if missing.
    - ``tokenizer_dir``: default ``output_root / "tokenizer"``.
    - ``normalization``: ``"nfc"`` (default), ``"nfkc"`` or ``"none"``.

    Both ``tokenizer.json`` and the ``PreTrainedTokenizerFast`` files are
    written to the tokenizer directory. When ``<s>`` and ``</s>`` exist, a
    template post-processor wraps encodings as ``<s> ... </s>`` whenever
    ``add_special_tokens`` is used.

    Raises:
        ValueError: ``normalization`` is not one of the supported values.
    """
    vocab_size = max(2048, as_int(tok_cfg.get("vocab_size"), 32000))
    min_frequency = max(1, as_int(tok_cfg.get("min_frequency"), 2))
    max_rows = max(100, as_int(tok_cfg.get("max_train_rows"), len(train_dataset)))
    corpus_rows = min(max_rows, len(train_dataset))

    # Configured special tokens come first (order kept, duplicates dropped). The built-in
    # control tokens are always appended if missing: the prompt template and the
    # PreTrainedTokenizerFast wrapper below both rely on them.
    default_specials = ["<pad>", "<unk>", "<s>", "</s>", "<|system|>", "<|user|>", "<|assistant|>"]
    special_tokens_cfg = tok_cfg.get("special_tokens")
    special_tokens = []
    if isinstance(special_tokens_cfg, list):
        for raw_token in special_tokens_cfg:
            token = as_text(raw_token)
            if token and token not in special_tokens:
                special_tokens.append(token)
    for token in default_specials:
        if token not in special_tokens:
            special_tokens.append(token)

    tokenizer_dir_raw = as_text(tok_cfg.get("tokenizer_dir"))
    tokenizer_dir = Path(tokenizer_dir_raw) if tokenizer_dir_raw else output_root / "tokenizer"
    tokenizer_dir.mkdir(parents=True, exist_ok=True)

    tokenizer = Tokenizer(models.BPE(unk_token="<unk>"))
    # NFC by default. NFKC folds compatibility characters that carry meaning in math and
    # formal-proof text (e.g. "ℕ" -> "N", "x²" -> "x2", "½" -> "1⁄2"), and that folding
    # cannot be undone on decode. Set tokenizer.normalization: nfkc to opt back in.
    normalization = (as_text(tok_cfg.get("normalization")) or "nfc").lower()
    if normalization == "nfc":
        tokenizer.normalizer = normalizers.NFC()
    elif normalization == "nfkc":
        tokenizer.normalizer = normalizers.NFKC()
    elif normalization != "none":
        raise ValueError("tokenizer.normalization must be one of: nfc, nfkc, none.")
    tokenizer.pre_tokenizer = pre_tokenizers.ByteLevel(add_prefix_space=False)
    tokenizer.decoder = decoders.ByteLevel()
    trainer = trainers.BpeTrainer(
        vocab_size=vocab_size,
        min_frequency=min_frequency,
        show_progress=True,
        special_tokens=special_tokens,
        # Seed the vocab with all 256 byte symbols. Otherwise bytes absent from the
        # sampled corpus would have no vocab entry and encode as <unk>.
        initial_alphabet=pre_tokenizers.ByteLevel.alphabet(),
    )
    print(
        "Training tokenizer from scratch: "
        f"rows={corpus_rows} vocab_size={vocab_size} min_frequency={min_frequency}"
    )
    tokenizer.train_from_iterator(
        iter_tokenizer_corpus(train_dataset, data_cfg, max_rows),
        trainer=trainer,
        length=corpus_rows,
    )

    bos_id = tokenizer.token_to_id("<s>")
    eos_id = tokenizer.token_to_id("</s>")
    if bos_id is not None and eos_id is not None:
        tokenizer.post_processor = processors.TemplateProcessing(
            single="<s> $A </s>",
            pair="<s> $A </s> <s> $B </s>",
            special_tokens=[("<s>", bos_id), ("</s>", eos_id)],
        )

    tokenizer_json_path = tokenizer_dir / "tokenizer.json"
    tokenizer.save(str(tokenizer_json_path))
    core_specials = {"<pad>", "<unk>", "<s>", "</s>"}
    extra_specials = [token for token in special_tokens if token not in core_specials]
    fast_tokenizer = PreTrainedTokenizerFast(
        tokenizer_file=str(tokenizer_json_path),
        bos_token="<s>",
        eos_token="</s>",
        unk_token="<unk>",
        pad_token="<pad>",
        additional_special_tokens=extra_specials,
    )
    # Same 128-token floor that the dataset tokenization and the model use for max_seq_length.
    fast_tokenizer.model_max_length = max(128, as_int(data_cfg.get("max_seq_length"), 2048))
    fast_tokenizer.save_pretrained(str(tokenizer_dir))
    return fast_tokenizer
def tokenize_datasets(raw: DatasetDict, tokenizer: PreTrainedTokenizerFast, data_cfg: Dict[str, Any]) -> DatasetDict:
    """Tokenize every split into causal-LM features, with prompt tokens masked out.

    Each row becomes ``prompt + answer + eos``, truncated to
    ``data.max_seq_length`` tokens (default 2048, floor 128). ``labels`` copies
    ``input_ids`` except that prompt positions are set to -100, so the loss only
    covers the answer (and EOS, if it survives truncation).

    Rows with no supervised token are dropped, for example when the prompt alone
    fills the length budget. The number dropped per split is printed.
    All original columns of each split are removed.
    """
    max_len = max(128, as_int(data_cfg.get("max_seq_length"), 2048))
    eos = tokenizer.eos_token or ""

    def _tokenize(row: Dict[str, Any]) -> Dict[str, Any]:
        prompt_text = build_prompt_text(row, data_cfg)
        answer_text = build_answer_block(row, data_cfg)
        full_text = f"{prompt_text}{answer_text}{eos}"
        # The prompt length is measured on the prompt alone. This assumes BPE does not
        # merge tokens across the prompt/answer boundary (the prompt ends in
        # "<|assistant|>\n") -- NOTE(review): confirm if the template changes.
        prompt_ids = tokenizer(prompt_text, add_special_tokens=False)["input_ids"]
        full_enc = tokenizer(
            full_text,
            add_special_tokens=False,
            truncation=True,
            max_length=max_len,
        )
        input_ids = full_enc["input_ids"]
        attention_mask = full_enc["attention_mask"]
        if not input_ids:
            # Defensive: emit a valid one-token row, fully masked so the filter below drops it.
            fallback = tokenizer.eos_token_id
            if fallback is None:
                fallback = tokenizer.pad_token_id
            if fallback is None:
                fallback = 0
            return {"input_ids": [fallback], "attention_mask": [1], "labels": [-100]}
        prompt_len = min(len(prompt_ids), len(input_ids))
        # No label is forced when truncation leaves only prompt tokens. Forcing one would
        # train on a prompt token and also stop the filter from ever dropping such rows.
        labels = [-100] * prompt_len + input_ids[prompt_len:]
        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "labels": labels,
        }

    # Map each split with its own column list; splits need not share a schema.
    mapped = DatasetDict(
        {
            name: split.map(
                _tokenize,
                remove_columns=split.column_names,
                desc="Tokenizing prompt/answer pairs",
            )
            for name, split in raw.items()
        }
    )
    tokenized = mapped.filter(
        lambda row: any(token != -100 for token in row["labels"]),
        desc="Dropping prompt-only rows",
    )
    for name, split in tokenized.items():
        dropped = len(mapped[name]) - len(split)
        if dropped:
            print(f"Dropped {dropped} row(s) without answer tokens from split '{name}'.")
    return tokenized
def build_model_from_scratch(model_cfg: Dict[str, Any], tokenizer: PreTrainedTokenizerFast, max_seq_length: int) -> GPT2LMHeadModel:
    """Create a randomly initialised GPT-2 language model sized by *model_cfg*.

    Defaults and floors:
    - ``n_layer``: default 12, floor 2.
    - ``n_head``: default 12, floor 2.
    - ``n_embd``: default 768, floor 128.
    - ``n_positions``: never below *max_seq_length*.
    - Dropout rates default to 0.1; ``initializer_range`` defaults to 0.02.

    The vocabulary size and the bos/eos/pad ids come from *tokenizer*.
    ``use_cache`` is switched off on the returned model's config.

    Raises:
        ValueError: ``n_embd`` is not divisible by ``n_head``.
    """
    layers = max(2, as_int(model_cfg.get("n_layer"), 12))
    heads = max(2, as_int(model_cfg.get("n_head"), 12))
    width = max(128, as_int(model_cfg.get("n_embd"), 768))
    if width % heads:
        raise ValueError("model.n_embd must be divisible by model.n_head.")
    context = max(max_seq_length, as_int(model_cfg.get("n_positions"), max_seq_length))

    dropout_rates = {
        name: as_float(model_cfg.get(name), 0.1)
        for name in ("resid_pdrop", "embd_pdrop", "attn_pdrop")
    }
    config = GPT2Config(
        vocab_size=len(tokenizer),
        n_positions=context,
        n_ctx=context,
        n_embd=width,
        n_layer=layers,
        n_head=heads,
        initializer_range=as_float(model_cfg.get("initializer_range"), 0.02),
        bos_token_id=tokenizer.bos_token_id,
        eos_token_id=tokenizer.eos_token_id,
        pad_token_id=tokenizer.pad_token_id,
        **dropout_rates,
    )
    model = GPT2LMHeadModel(config)
    model.config.use_cache = False
    return model
def build_training_args(cfg: Dict[str, Any], has_eval_split: bool) -> TrainingArguments:
    """Translate ``cfg["training"]`` (plus ``model.use_bf16``) into ``TrainingArguments``.

    Creates ``training.output_dir`` if it does not exist. Step-based evaluation
    every ``eval_steps`` is enabled only when *has_eval_split* is true.

    Mixed precision is used only on CUDA:
    - bf16 when requested and the GPU supports it;
    - fp16 otherwise, including as a fallback when bf16 was requested but is
      unsupported.
    """
    import inspect

    model_cfg = cfg["model"]
    training_cfg = cfg["training"]

    use_bf16_requested = bool(model_cfg.get("use_bf16", True))
    cuda_available = torch.cuda.is_available()
    # Only query bf16 support when CUDA exists; passing bf16=True on a GPU without
    # bf16 support is rejected by transformers, so fall back to fp16 instead.
    bf16 = use_bf16_requested and cuda_available and torch.cuda.is_bf16_supported()
    fp16 = cuda_available and not bf16
    if use_bf16_requested and cuda_available and not bf16:
        print("bf16 requested but not supported on this GPU; using fp16 instead.")

    output_dir = Path(as_text(training_cfg.get("output_dir")))
    output_dir.mkdir(parents=True, exist_ok=True)

    # transformers renamed ``evaluation_strategy`` to ``eval_strategy`` and later removed
    # the old keyword, so pass whichever name the installed version accepts.
    accepted = inspect.signature(TrainingArguments.__init__).parameters
    strategy_key = "eval_strategy" if "eval_strategy" in accepted else "evaluation_strategy"

    return TrainingArguments(
        output_dir=str(output_dir),
        num_train_epochs=as_float(training_cfg.get("num_train_epochs"), 1.0),
        # -1 (also used when unset) lets num_train_epochs control the run length.
        max_steps=as_int(training_cfg.get("max_steps"), -1),
        per_device_train_batch_size=max(1, as_int(training_cfg.get("per_device_train_batch_size"), 1)),
        per_device_eval_batch_size=max(1, as_int(training_cfg.get("per_device_eval_batch_size"), 1)),
        gradient_accumulation_steps=max(1, as_int(training_cfg.get("gradient_accumulation_steps"), 1)),
        learning_rate=as_float(training_cfg.get("learning_rate"), 2e-4),
        weight_decay=as_float(training_cfg.get("weight_decay"), 0.0),
        warmup_ratio=as_float(training_cfg.get("warmup_ratio"), 0.0),
        lr_scheduler_type=as_text(training_cfg.get("lr_scheduler_type")) or "cosine",
        max_grad_norm=as_float(training_cfg.get("max_grad_norm"), 1.0),
        gradient_checkpointing=bool(training_cfg.get("gradient_checkpointing", True)),
        logging_steps=max(1, as_int(training_cfg.get("logging_steps"), 10)),
        save_steps=max(1, as_int(training_cfg.get("save_steps"), 250)),
        save_total_limit=max(1, as_int(training_cfg.get("save_total_limit"), 3)),
        dataloader_num_workers=max(0, as_int(training_cfg.get("dataloader_num_workers"), 0)),
        seed=as_int(training_cfg.get("seed"), 17),
        bf16=bf16,
        fp16=fp16,
        # Keep the tokenized columns (input_ids/attention_mask/labels) as-is.
        remove_unused_columns=False,
        report_to="none",
        eval_steps=max(1, as_int(training_cfg.get("eval_steps"), 250)) if has_eval_split else None,
        **{strategy_key: "steps" if has_eval_split else "no"},
    )
def resolve_repo_id(cfg: Dict[str, Any], username: Optional[str]) -> Optional[str]:
    """Choose the Hub repository id to push to.

    Precedence:
    1. an explicit ``hub.repo_id``;
    2. ``"<username>/<name>"``, where ``<name>`` is the last path component of
       ``training.output_dir``;
    3. ``None`` when no username is available.
    """
    explicit = as_text(cfg.get("hub", {}).get("repo_id"))
    if explicit:
        return explicit
    if username:
        run_name = Path(as_text(cfg["training"].get("output_dir"))).name
        return f"{username}/{run_name}"
    return None
def push_output_to_hub(output_dir: Path, repo_id: str, token: str, private: bool, commit_message: str) -> None:
    """Create the model repo *repo_id* if needed, then upload *output_dir* in one commit.

    NOTE(review): the whole folder is uploaded. If the Trainer left checkpoint
    subfolders there, they are pushed too -- confirm that is intended.
    """
    client = HfApi(token=token)
    target = {"repo_id": repo_id, "repo_type": "model"}
    client.create_repo(private=private, exist_ok=True, **target)
    client.upload_folder(folder_path=str(output_dir), commit_message=commit_message, **target)
def save_json(path: Path, payload: Dict[str, Any]) -> None:
    """Write *payload* to *path* as 2-space-indented, ASCII-escaped JSON.

    Missing parent directories are created, and a trailing newline is appended.
    The payload is serialised before the file is touched, so a non-serialisable
    payload raises without creating or truncating *path*.
    """
    path.parent.mkdir(parents=True, exist_ok=True)
    document = json.dumps(payload, ensure_ascii=True, indent=2)
    path.write_text(f"{document}\n", encoding="utf-8")
def _push_if_enabled(
    push_to_hub: bool,
    cfg: Dict[str, Any],
    folder: Path,
    repo_id: Optional[str],
    token: Optional[str],
    default_message: str,
) -> None:
    """Upload *folder* to the Hub when pushing is enabled and both repo_id and token are known."""
    if not push_to_hub or repo_id is None or token is None:
        return
    # ``hub`` may be missing or an empty YAML key (None).
    hub_cfg = cfg.get("hub") or {}
    commit_message = as_text(hub_cfg.get("commit_message")) or default_message
    private = bool(hub_cfg.get("private", False))
    push_output_to_hub(folder, repo_id, token, private, commit_message)
    print(f"Pushed model artifacts to https://huggingface.co/{repo_id}")


def main() -> None:
    """Command-line entry point.

    Steps:
    1. Load the YAML config and apply CLI overrides.
    2. Train a tokenizer on the (optionally subset) train split.
    3. Build a randomly initialised GPT-2 model.
    4. Either save the untrained artifacts (``--init-only`` / ``--dry-run``) or
       train with ``Trainer`` and save the result.

    Artifacts are pushed to the Hub when ``hub.push_to_hub`` is enabled, never
    under ``--dry-run``.

    Raises:
        ValueError: ``global.output_root`` is empty, or a Hub push was requested
            without a token or a resolvable repo id.
    """
    import inspect

    args = parse_args()
    cfg = load_config(args.config)
    apply_overrides(cfg, args)
    global_cfg = cfg["global"]
    data_cfg = cfg["data"]
    training_cfg = cfg["training"]
    tokenizer_cfg = cfg["tokenizer"]

    # Check the raw text, not a Path: Path("") is Path("."), and Path objects are always
    # truthy, so testing the Path could never trip this check.
    output_root_text = as_text(global_cfg.get("output_root"))
    if not output_root_text:
        raise ValueError("global.output_root is required.")
    output_root = Path(output_root_text)
    output_root.mkdir(parents=True, exist_ok=True)
    if not as_text(training_cfg.get("output_dir")):
        training_cfg["output_dir"] = str(output_root / "checkpoints")

    seed = as_int(global_cfg.get("seed"), 17)
    training_cfg.setdefault("seed", seed)
    set_seed(seed)

    token, username = resolve_auth(cfg)
    # --dry-run never pushes, whatever the config or other flags say.
    push_to_hub = bool((cfg.get("hub") or {}).get("push_to_hub", False)) and not args.dry_run
    repo_id = resolve_repo_id(cfg, username)
    if push_to_hub:
        if token is None:
            raise ValueError("Hub push requested but no token found.")
        if repo_id is None:
            raise ValueError("Hub push requested but repo_id is empty and username is unavailable.")

    raw = load_raw_datasets(data_cfg)
    raw["train"] = maybe_select(raw["train"], data_cfg.get("max_train_samples"))
    raw["validation"] = maybe_select(raw["validation"], data_cfg.get("max_eval_samples"))

    tokenizer = build_tokenizer(raw["train"], tokenizer_cfg, data_cfg, output_root)
    # Same directory choice as build_tokenizer, so the summary reports where it really went.
    tokenizer_dir = Path(as_text(tokenizer_cfg.get("tokenizer_dir")) or output_root / "tokenizer")
    max_seq_length = max(128, as_int(data_cfg.get("max_seq_length"), 2048))
    model = build_model_from_scratch(cfg["model"], tokenizer, max_seq_length)

    output_dir = Path(as_text(training_cfg.get("output_dir")))
    output_dir.mkdir(parents=True, exist_ok=True)
    model_size = {
        "total_parameters": int(sum(p.numel() for p in model.parameters())),
        "trainable_parameters": int(sum(p.numel() for p in model.parameters() if p.requires_grad)),
    }

    if args.init_only or args.dry_run:
        model.save_pretrained(str(output_dir), safe_serialization=True)
        tokenizer.save_pretrained(str(output_dir))
        summary = {
            "mode": "dry_run" if args.dry_run else "init_only",
            "output_dir": str(output_dir),
            "tokenizer_dir": str(tokenizer_dir.resolve()),
            "rows_train": len(raw["train"]),
            "rows_validation": len(raw["validation"]),
            "max_seq_length": max_seq_length,
            "model": model_size,
            "config_path": str(args.config),
        }
        save_json(output_root / "scratch_init_summary.json", summary)
        save_json(output_dir / "resolved_training_config.json", cfg)
        _push_if_enabled(push_to_hub, cfg, output_dir, repo_id, token, "Upload scratch-initialized model.")
        print(f"Scratch initialization complete. Output saved to: {output_dir}")
        return

    tokenized = tokenize_datasets(raw, tokenizer, data_cfg)
    train_dataset = tokenized["train"]
    eval_dataset = tokenized["validation"] if len(tokenized["validation"]) > 0 else None
    training_args = build_training_args(cfg, has_eval_split=eval_dataset is not None)
    data_collator = DataCollatorForSeq2Seq(
        tokenizer=tokenizer,
        model=model,
        label_pad_token_id=-100,
        pad_to_multiple_of=8,
    )
    trainer_kwargs: Dict[str, Any] = {
        "model": model,
        "args": training_args,
        "train_dataset": train_dataset,
        "eval_dataset": eval_dataset,
        "data_collator": data_collator,
    }
    # Newer transformers take the tokenizer as ``processing_class`` (``tokenizer`` is
    # deprecated there); pass it under whichever keyword this version accepts.
    trainer_params = inspect.signature(Trainer.__init__).parameters
    tokenizer_key = "processing_class" if "processing_class" in trainer_params else "tokenizer"
    trainer_kwargs[tokenizer_key] = tokenizer
    trainer = Trainer(**trainer_kwargs)

    train_result = trainer.train()
    trainer.log_metrics("train", train_result.metrics)
    trainer.save_metrics("train", train_result.metrics)
    trainer.save_state()
    if eval_dataset is not None:
        eval_metrics = trainer.evaluate()
        trainer.log_metrics("eval", eval_metrics)
        trainer.save_metrics("eval", eval_metrics)

    trainer.save_model(training_args.output_dir)
    tokenizer.save_pretrained(training_args.output_dir)
    save_json(output_dir / "resolved_training_config.json", cfg)
    save_json(
        output_dir / "scratch_model_summary.json",
        {
            "output_dir": str(output_dir),
            "rows_train": len(train_dataset),
            "rows_validation": len(eval_dataset) if eval_dataset is not None else 0,
            "max_seq_length": max_seq_length,
            "model": model_size,
            "config_path": str(args.config),
        },
    )
    _push_if_enabled(
        push_to_hub, cfg, Path(training_args.output_dir), repo_id, token, "Upload scratch-trained model."
    )
    print(f"Training finished. Output saved to: {training_args.output_dir}")
# Script entry: run main() and exit with its (None -> 0) status.
if __name__ == "__main__":
    raise SystemExit(main())