diff --git a/.riahub/workflows/train.yaml b/.riahub/workflows/train.yaml index a3c1537..af8db83 100644 --- a/.riahub/workflows/train.yaml +++ b/.riahub/workflows/train.yaml @@ -321,7 +321,7 @@ jobs: # `--device cpu` from the Train step actually takes effect. # No-op if the line already uses args.device (idempotent). if [[ -f main_finetune.py ]]; then - sed -i 's|torch\.amp\.GradScaler(device="cuda")|torch.amp.GradScaler(device=args.device, enabled=(args.device != "cpu"))|' main_finetune.py + sed -i 's|torch\.amp\.GradScaler(device="cuda")|torch.cuda.amp.GradScaler(enabled=(args.device != "cpu"))|' main_finetune.py echo "Patched main_finetune.py GradScaler for CPU/GPU device parity." fi