Add utility to inspect a model's parameters (to get dtype/device)

This commit is contained in:
Aarni Koskela
2023-12-31 00:20:30 +02:00
parent a84e842189
commit 5768afc776
8 changed files with 53 additions and 7 deletions

17
modules/torch_utils.py Normal file
View File

@@ -0,0 +1,17 @@
from __future__ import annotations
import torch.nn
def get_param(model) -> torch.nn.Parameter:
"""
Find the first parameter in a model or module.
"""
if hasattr(model, "model") and hasattr(model.model, "parameters"):
# Unpeel a model descriptor to get at the actual Torch module.
model = model.model
for param in model.parameters():
return param
raise ValueError(f"No parameters found in model {model!r}")