Use devices.autocast instead of torch.autocast

This commit is contained in:
brkirch
2022-11-28 21:36:35 -05:00
parent 21effd629d
commit 4d5f1691dd
5 changed files with 6 additions and 11 deletions

View File

@@ -495,7 +495,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, gradient_step,
if shared.state.interrupted:
break
- with torch.autocast("cuda"):
+ with devices.autocast():
x = batch.latent_sample.to(devices.device, non_blocking=pin_memory)
if tag_drop_out != 0 or shuffle_tags:
shared.sd_model.cond_stage_model.to(devices.device)