#!/usr/bin/env bash
#
# Launch the se3_transformer.runtime.training module through
# torch.distributed.run on a single node.
#
# Usage:
#   <this script> [BATCH_SIZE] [AMP] [NUM_EPOCHS] [LEARNING_RATE] [WEIGHT_DECAY]
#
# Every positional argument is optional. An omitted or empty argument falls
# back to the default shown in the assignments below.

# Abort on command failure, on use of an unset variable, and on a failure in
# any pipeline stage.
set -euo pipefail

# CLI args with defaults.
BATCH_SIZE=${1:-240}        # forwarded as --batch_size
AMP=${2:-true}              # forwarded as --amp
NUM_EPOCHS=${3:-130}        # forwarded as --epochs
LEARNING_RATE=${4:-0.01}    # forwarded as --lr
WEIGHT_DECAY=${5:-0.1}      # forwarded as --weight_decay

# Value forwarded as --task. It is fixed here, not read from the command line.
TASK=homo

# Options for the training module. Holding them in an array keeps each value a
# single word and avoids fragile trailing-backslash line continuations.
training_args=(
  --amp "$AMP"
  --batch_size "$BATCH_SIZE"
  --epochs "$NUM_EPOCHS"
  --lr "$LEARNING_RATE"
  --min_lr 0.00001
  --weight_decay "$WEIGHT_DECAY"
  --use_layer_norm
  --norm
  --save_ckpt_path model_qm9.pth
  --precompute_bases
  --seed 42
  --task "$TASK"
)

# Launcher options come first; everything after the module name is passed
# through to the training module.
# NOTE(review): per the torchrun documentation, --nproc_per_node=gpu starts one
# worker per visible GPU; confirm against the installed PyTorch version.
python -m torch.distributed.run \
  --nnodes=1 \
  --nproc_per_node=gpu \
  --max_restarts 0 \
  --module se3_transformer.runtime.training \
  "${training_args[@]}"