forked from qoherent/icc-demo
Workflow: use torch.cuda.amp.GradScaler (torch 2.2.x API)
This commit is contained in:
parent
9dd29ac589
commit
9cb3f35225
|
|
@ -321,7 +321,7 @@ jobs:
|
|||
# `--device cpu` from the Train step actually takes effect.
|
||||
# No-op if the line already uses args.device (idempotent).
|
||||
if [[ -f main_finetune.py ]]; then
|
||||
sed -i 's|torch\.amp\.GradScaler(device="cuda")|torch.amp.GradScaler(device=args.device, enabled=(args.device != "cpu"))|' main_finetune.py
|
||||
sed -i 's|torch\.amp\.GradScaler(device="cuda")|torch.cuda.amp.GradScaler(enabled=(args.device != "cpu"))|' main_finetune.py
|
||||
echo "Patched main_finetune.py GradScaler for CPU/GPU device parity."
|
||||
fi
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user