mirror of
https://github.com/AUTOMATIC1111/stable-diffusion-webui.git
synced 2025-08-04 03:10:21 +00:00
Add utility to inspect a model's parameters (to get dtype/device)
This commit is contained in:
19
test/test_torch_utils.py
Normal file
19
test/test_torch_utils.py
Normal file
@@ -0,0 +1,19 @@
|
||||
import types
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from modules.torch_utils import get_param
|
||||
|
||||
|
||||
@pytest.mark.parametrize("wrapped", [True, False])
|
||||
def test_get_param(wrapped):
|
||||
mod = torch.nn.Linear(1, 1)
|
||||
cpu = torch.device("cpu")
|
||||
mod.to(dtype=torch.float16, device=cpu)
|
||||
if wrapped:
|
||||
# more or less how spandrel wraps a thing
|
||||
mod = types.SimpleNamespace(model=mod)
|
||||
p = get_param(mod)
|
||||
assert p.dtype == torch.float16
|
||||
assert p.device == cpu
|
Reference in New Issue
Block a user