style: formatting

This commit is contained in:
v0xie
2023-10-19 12:52:14 -07:00
parent 321680ccd0
commit d10c4db57e
2 changed files with 2 additions and 37 deletions

View File

@@ -37,7 +37,7 @@ class NetworkModuleOFT(network.NetworkModule):
def apply_to(self):
self.org_forward = self.org_module[0].forward
self.org_module[0].forward = self.forward
def get_weight(self, oft_blocks, multiplier=None):
block_Q = oft_blocks - oft_blocks.transpose(1, 2)
norm_Q = torch.norm(block_Q.flatten())
@@ -66,7 +66,7 @@ class NetworkModuleOFT(network.NetworkModule):
output_shape = self.oft_blocks.shape
return self.finalize_updown(updown, orig_weight, output_shape)
def forward(self, x, y=None):
x = self.org_forward(x)
if self.multiplier() == 0.0: