Use devices.autocast instead of torch.autocast

This commit is contained in:
brkirch
2022-11-28 21:36:35 -05:00
parent 21effd629d
commit 4d5f1691dd
5 changed files with 6 additions and 11 deletions

View File

@@ -316,7 +316,7 @@ def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_
if shared.state.interrupted:
break
with torch.autocast("cuda"):
with devices.autocast():
# c = stack_conds(batch.cond).to(devices.device)
# mask = torch.tensor(batch.emb_index).to(devices.device, non_blocking=pin_memory)
# print(mask)