Workflow: use torch.cuda.amp.GradScaler (torch 2.2.x API)

This commit is contained in:
P Roman Pope 2026-05-28 02:59:48 -04:00
parent 9dd29ac589
commit 9cb3f35225

View File

@ -321,7 +321,7 @@ jobs:
# `--device cpu` from the Train step actually takes effect.
# No-op if the line already uses args.device (idempotent).
if [[ -f main_finetune.py ]]; then
sed -i 's|torch\.amp\.GradScaler(device="cuda")|torch.amp.GradScaler(device=args.device, enabled=(args.device != "cpu"))|' main_finetune.py
sed -i 's|torch\.amp\.GradScaler(device="cuda")|torch.cuda.amp.GradScaler(enabled=(args.device != "cpu"))|' main_finetune.py
echo "Patched main_finetune.py GradScaler for CPU/GPU device parity."
fi