mirror of
https://github.com/AUTOMATIC1111/stable-diffusion-webui.git
synced 2025-08-08 05:12:35 +00:00
feat(api): return more data for embeddings
This commit is contained in:
@@ -330,9 +330,22 @@ class Api:
|
||||
|
||||
def get_embeddings(self):
|
||||
db = sd_hijack.model_hijack.embedding_db
|
||||
|
||||
def convert_embedding(embedding):
|
||||
return {
|
||||
"step": embedding.step,
|
||||
"sd_checkpoint": embedding.sd_checkpoint,
|
||||
"sd_checkpoint_name": embedding.sd_checkpoint_name,
|
||||
"shape": embedding.shape,
|
||||
"vectors": embedding.vectors,
|
||||
}
|
||||
|
||||
def convert_embeddings(embeddings):
|
||||
return {embedding.name: convert_embedding(embedding) for embedding in embeddings.values()}
|
||||
|
||||
return {
|
||||
"loaded": sorted(db.word_embeddings.keys()),
|
||||
"skipped": sorted(db.skipped_embeddings),
|
||||
"loaded": convert_embeddings(db.word_embeddings),
|
||||
"skipped": convert_embeddings(db.skipped_embeddings),
|
||||
}
|
||||
|
||||
def refresh_checkpoints(self):
|
||||
|
@@ -249,6 +249,13 @@ class ArtistItem(BaseModel):
|
||||
score: float = Field(title="Score")
|
||||
category: str = Field(title="Category")
|
||||
|
||||
class EmbeddingItem(BaseModel):
|
||||
step: Optional[int] = Field(title="Step", description="The number of steps that were used to train this embedding, if available")
|
||||
sd_checkpoint: Optional[str] = Field(title="SD Checkpoint", description="The hash of the checkpoint this embedding was trained on, if available")
|
||||
sd_checkpoint_name: Optional[str] = Field(title="SD Checkpoint Name", description="The name of the checkpoint this embedding was trained on, if available. Note that this is the name that was used by the trainer; for a stable identifier, use `sd_checkpoint` instead")
|
||||
shape: int = Field(title="Shape", description="The length of each individual vector in the embedding")
|
||||
vectors: int = Field(title="Vectors", description="The number of vectors in the embedding")
|
||||
|
||||
class EmbeddingsResponse(BaseModel):
|
||||
loaded: List[str] = Field(title="Loaded", description="Embeddings loaded for the current model")
|
||||
skipped: List[str] = Field(title="Skipped", description="Embeddings skipped for the current model (likely due to architecture incompatibility)")
|
||||
loaded: Dict[str, EmbeddingItem] = Field(title="Loaded", description="Embeddings loaded for the current model")
|
||||
skipped: Dict[str, EmbeddingItem] = Field(title="Skipped", description="Embeddings skipped for the current model (likely due to architecture incompatibility)")
|
Reference in New Issue
Block a user