mirror of https://github.com/AUTOMATIC1111/stable-diffusion-webui.git
fix linter issues
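The first hunk drops BertModel from the transformers import because the name is never referenced in the file; the remaining hunks pair lines that render identically here, which is consistent with whitespace-only cleanup (trailing spaces on code lines and on blank lines) rather than logic changes. The commit does not say which linter was used; assuming a pyflakes-style checker such as ruff or flake8, the unused import corresponds to rule F401. A minimal sketch of that fix, with the pre-commit import shown as a comment:

# Flagged by pyflakes-style linters as F401 ("imported but unused"):
# from transformers import BertPreTrainedModel,BertModel,BertConfig

# After the fix, only the names the module actually uses are imported.
from transformers import BertPreTrainedModel,BertConfig
import torch.nn as nn
import torch
from transformers.models.xlm_roberta.configuration_xlm_roberta import XLMRobertaConfig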
@@ -1,4 +1,4 @@
-from transformers import BertPreTrainedModel,BertModel,BertConfig
+from transformers import BertPreTrainedModel,BertConfig
 import torch.nn as nn
 import torch
 from transformers.models.xlm_roberta.configuration_xlm_roberta import XLMRobertaConfig
@@ -28,7 +28,7 @@ class BertSeriesModelWithTransformation(BertPreTrainedModel):
     config_class = BertSeriesConfig

     def __init__(self, config=None, **kargs):
-        # modify initialization for autoloading
+        # modify initialization for autoloading
         if config is None:
             config = XLMRobertaConfig()
             config.attention_probs_dropout_prob= 0.1
@@ -80,7 +80,7 @@ class BertSeriesModelWithTransformation(BertPreTrainedModel):
         text["attention_mask"] = torch.tensor(
             text['attention_mask']).to(device)
         features = self(**text)
-        return features['projection_state']
+        return features['projection_state']

     def forward(
         self,
@@ -147,8 +147,8 @@ class BertSeriesModelWithTransformation(BertPreTrainedModel):
             "hidden_states": outputs.hidden_states,
             "attentions": outputs.attentions,
         }
-
-
+
+
         # return {
         #     'pooler_output':pooler_output,
         #     'last_hidden_state':outputs.last_hidden_state,
@@ -161,4 +161,4 @@ class BertSeriesModelWithTransformation(BertPreTrainedModel):

 class RobertaSeriesModelWithTransformation(BertSeriesModelWithTransformation):
     base_model_prefix = 'roberta'
-    config_class= RobertaSeriesConfig
+    config_class= RobertaSeriesConfig
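For context on the projection_state value returned in the second hunk: the imports suggest an XLM-RoBERTa-style text encoder, and callers read the 'projection_state' entry of the forward output as the text embedding. A minimal usage sketch, assuming an already-constructed BertSeriesModelWithTransformation instance; the tokenizer checkpoint, the encode_prompt helper, and the sequence length of 77 are illustrative assumptions, not part of this diff:

import torch
from transformers import XLMRobertaTokenizer

def encode_prompt(model, prompt):
    # `model` is assumed to be a BertSeriesModelWithTransformation instance
    # loaded elsewhere; the tokenizer checkpoint below is a placeholder.
    tokenizer = XLMRobertaTokenizer.from_pretrained("xlm-roberta-base")
    device = next(model.parameters()).device
    text = tokenizer(prompt, padding="max_length", truncation=True,
                     max_length=77, return_tensors="pt")
    # Same pattern as the diff context: move the tensors to the model's
    # device, run the forward pass, and keep the projection_state entry.
    text = {k: v.to(device) for k, v in text.items()}
    features = model(**text)
    return features["projection_state"]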