More informative progress printing

This commit is contained in:
JohannesGaessler
2022-09-08 15:37:13 +02:00
parent ad02b249f5
commit f211c498b9
5 changed files with 43 additions and 2 deletions

View File

@@ -70,13 +70,14 @@ def extended_tdqm(sequence, *args, desc=None, **kwargs):
state.sampling_steps = len(sequence)
state.sampling_step = 0
for x in tqdm.tqdm(sequence, *args, desc=state.job, **kwargs):
for x in tqdm.tqdm(sequence, *args, desc=state.job, file=shared.progress_print_out, **kwargs):
if state.interrupted:
break
yield x
state.sampling_step += 1
shared.total_tqdm.update()
ldm.models.diffusion.ddim.tqdm = lambda *args, desc=None, **kwargs: extended_tdqm(*args, desc=desc, **kwargs)
@@ -146,13 +147,14 @@ def extended_trange(count, *args, **kwargs):
state.sampling_steps = count
state.sampling_step = 0
for x in tqdm.trange(count, *args, desc=state.job, **kwargs):
for x in tqdm.trange(count, *args, desc=state.job, file=shared.progress_print_out, **kwargs):
if state.interrupted:
break
yield x
state.sampling_step += 1
shared.total_tqdm.update()
class KDiffusionSampler: