Use GPU for loading safetensors, disable export

This commit is contained in:
Tim Patton
2022-11-21 16:40:18 -05:00
parent e134b74ce9
commit 210cb4c128
2 changed files with 5 additions and 3 deletions

View File

@@ -147,8 +147,9 @@ def torch_load(model_filename, model_info, map_override=None):
map_override=shared.weight_load_location if not map_override else map_override
if(checkpoint_types[model_info.exttype] == 'safetensors'):
# safely load weights
# TODO: safetensors supports zero copy fast load to gpu, see issue #684
return load_file(model_filename, device=map_override)
# TODO: safetensors supports zero-copy fast loading to the GPU, see issue #684.
# The device is hard-coded to 'cuda' (GPU only) for now, see https://github.com/huggingface/safetensors/issues/95
return load_file(model_filename, device='cuda')
else:
return torch.load(model_filename, map_location=map_override)