This commit is contained in:
wangshuai09
2024-01-31 10:46:53 +08:00
parent 74ff85a1a1
commit cc3f604310
8 changed files with 22 additions and 20 deletions

View File

@@ -150,10 +150,7 @@ class EmbeddingDatabase:
return embedding
def get_expected_shape(self):
# workaround
if devices.npu_specific.has_npu:
import torch
torch.npu.set_device(0)
devices.torch_npu_set_device()
vec = shared.sd_model.cond_stage_model.encode_embedding_init_text(",", 1)
return vec.shape[1]