send tensors to the correct device when loading from safetensors file with memmap disabled for #11260

This commit is contained in:
AUTOMATIC
2023-06-27 09:19:04 +03:00
parent 14196548c5
commit 24129368f1
2 changed files with 4 additions and 2 deletions

View File

@@ -246,11 +246,13 @@ def read_metadata_from_safetensors(filename):
def read_state_dict(checkpoint_file, print_global_state=False, map_location=None):
_, extension = os.path.splitext(checkpoint_file)
if extension.lower() == ".safetensors":
device = map_location or shared.weight_load_location or devices.get_optimal_device_name()
if not shared.opts.disable_mmap_load_safetensors:
device = map_location or shared.weight_load_location or devices.get_optimal_device_name()
pl_sd = safetensors.torch.load_file(checkpoint_file, device=device)
else:
pl_sd = safetensors.torch.load(open(checkpoint_file, 'rb').read())
pl_sd = {k: v.to(device) for k, v in pl_sd.items()}
else:
pl_sd = torch.load(checkpoint_file, map_location=map_location or shared.weight_load_location)