Add UniPC sampler settings

This commit is contained in:
space-nuko
2023-02-10 05:27:05 -08:00
parent c88dcc20d4
commit 79ffb9453f
4 changed files with 16 additions and 3 deletions

View File

@@ -750,7 +750,7 @@ class UniPC:
if method == 'multistep':
assert steps >= order, "UniPC order must be < sampling steps"
timesteps = self.get_time_steps(skip_type=skip_type, t_T=t_T, t_0=t_0, N=steps, device=device)
print(f"Running UniPC Sampling with {timesteps.shape[0]} timesteps")
print(f"Running UniPC Sampling with {timesteps.shape[0]} timesteps, order {order}")
assert timesteps.shape[0] - 1 == steps
with torch.no_grad():
vec_t = timesteps[0].expand((x.shape[0]))