ruff manual fixes

This commit is contained in:
AUTOMATIC
2023-05-10 11:19:16 +03:00
parent 028d3f6425
commit 550256db1c
15 changed files with 69 additions and 41 deletions

View File

@@ -178,13 +178,13 @@ def model_wrapper(
model,
noise_schedule,
model_type="noise",
model_kwargs={},
model_kwargs=None,
guidance_type="uncond",
#condition=None,
#unconditional_condition=None,
guidance_scale=1.,
classifier_fn=None,
classifier_kwargs={},
classifier_kwargs=None,
):
"""Create a wrapper function for the noise prediction model.
@@ -275,6 +275,9 @@ def model_wrapper(
A noise prediction model that accepts the noised data and the continuous time as the inputs.
"""
model_kwargs = model_kwargs or []
classifier_kwargs = classifier_kwargs or []
def get_model_input_time(t_continuous):
"""
Convert the continuous-time `t_continuous` (in [epsilon, T]) to the model input time.