diff --git a/pg_modules/networks_fastgan.py b/pg_modules/networks_fastgan.py index 1a32056..e96f9b9 100644 --- a/pg_modules/networks_fastgan.py +++ b/pg_modules/networks_fastgan.py @@ -3,7 +3,15 @@ # modified by Axel Sauer for "Projected GANs Converge Faster" # import torch.nn as nn -from pg_modules.blocks import (InitLayer, UpBlockBig, UpBlockBigCond, UpBlockSmall, UpBlockSmallCond, SEBlock, conv2d) +from pg_modules.blocks import ( + InitLayer, + UpBlockBig, + UpBlockBigCond, + UpBlockSmall, + UpBlockSmallCond, + SEBlock, + conv2d, +) def normalize_second_moment(x, dim=1, eps=1e-8): @@ -25,25 +33,35 @@ def __init__(self, ngf=128, z_dim=256, nc=3, img_resolution=256, lite=False): self.z_dim = z_dim # channel multiplier - nfc_multi = {2: 16, 4:16, 8:8, 16:4, 32:2, 64:2, 128:1, 256:0.5, - 512:0.25, 1024:0.125} + nfc_multi = { + 2: 16, + 4: 16, + 8: 8, + 16: 4, + 32: 2, + 64: 2, + 128: 1, + 256: 0.5, + 512: 0.25, + 1024: 0.125, + } nfc = {} for k, v in nfc_multi.items(): - nfc[k] = int(v*ngf) + nfc[k] = int(v * ngf) # layers self.init = InitLayer(z_dim, channel=nfc[2], sz=4) UpBlock = UpBlockSmall if lite else UpBlockBig - self.feat_8 = UpBlock(nfc[4], nfc[8]) - self.feat_16 = UpBlock(nfc[8], nfc[16]) - self.feat_32 = UpBlock(nfc[16], nfc[32]) - self.feat_64 = UpBlock(nfc[32], nfc[64]) + self.feat_8 = UpBlock(nfc[4], nfc[8]) + self.feat_16 = UpBlock(nfc[8], nfc[16]) + self.feat_32 = UpBlock(nfc[16], nfc[32]) + self.feat_64 = UpBlock(nfc[32], nfc[64]) self.feat_128 = UpBlock(nfc[64], nfc[128]) self.feat_256 = UpBlock(nfc[128], nfc[256]) - self.se_64 = SEBlock(nfc[4], nfc[64]) + self.se_64 = SEBlock(nfc[4], nfc[64]) self.se_128 = SEBlock(nfc[8], nfc[128]) self.se_256 = SEBlock(nfc[16], nfc[256]) @@ -64,7 +82,7 @@ def forward(self, input, c, **kwargs): feat_16 = self.feat_16(feat_8) feat_32 = self.feat_32(feat_16) feat_64 = self.se_64(feat_4, self.feat_64(feat_32)) - feat_128 = self.se_128(feat_8, self.feat_128(feat_64)) + feat_128 = self.se_128(feat_8, self.feat_128(feat_64)) if self.img_resolution >= 128: feat_last = feat_128 @@ -82,15 +100,28 @@ def forward(self, input, c, **kwargs): class FastganSynthesisCond(nn.Module): - def __init__(self, ngf=64, z_dim=256, nc=3, img_resolution=256, num_classes=1000, lite=False): + def __init__( + self, ngf=64, z_dim=256, nc=3, img_resolution=256, num_classes=1000, lite=False + ): super().__init__() self.z_dim = z_dim - nfc_multi = {2: 16, 4:16, 8:8, 16:4, 32:2, 64:2, 128:1, 256:0.5, - 512:0.25, 1024:0.125, 2048:0.125} + nfc_multi = { + 2: 16, + 4: 16, + 8: 8, + 16: 4, + 32: 2, + 64: 2, + 128: 1, + 256: 0.5, + 512: 0.25, + 1024: 0.125, + 2048: 0.125, + } nfc = {} for k, v in nfc_multi.items(): - nfc[k] = int(v*ngf) + nfc[k] = int(v * ngf) self.img_resolution = img_resolution @@ -98,10 +129,10 @@ def __init__(self, ngf=64, z_dim=256, nc=3, img_resolution=256, num_classes=1000 UpBlock = UpBlockSmallCond if lite else UpBlockBigCond - self.feat_8 = UpBlock(nfc[4], nfc[8], z_dim) - self.feat_16 = UpBlock(nfc[8], nfc[16], z_dim) - self.feat_32 = UpBlock(nfc[16], nfc[32], z_dim) - self.feat_64 = UpBlock(nfc[32], nfc[64], z_dim) + self.feat_8 = UpBlock(nfc[4], nfc[8], z_dim) + self.feat_16 = UpBlock(nfc[8], nfc[16], z_dim) + self.feat_32 = UpBlock(nfc[16], nfc[32], z_dim) + self.feat_64 = UpBlock(nfc[32], nfc[64], z_dim) self.feat_128 = UpBlock(nfc[64], nfc[128], z_dim) self.feat_256 = UpBlock(nfc[128], nfc[256], z_dim) @@ -112,10 +143,10 @@ def __init__(self, ngf=64, z_dim=256, nc=3, img_resolution=256, num_classes=1000 self.to_big = conv2d(nfc[img_resolution], nc, 3, 1, 1, bias=True) if img_resolution > 256: - self.feat_512 = UpBlock(nfc[256], nfc[512]) + self.feat_512 = UpBlock(nfc[256], nfc[512], z_dim) self.se_512 = SEBlock(nfc[32], nfc[512]) if img_resolution > 512: - self.feat_1024 = UpBlock(nfc[512], nfc[1024]) + self.feat_1024 = UpBlock(nfc[512], nfc[1024], z_dim) self.embed = nn.Embedding(num_classes, z_dim) @@ -130,7 +161,7 @@ def forward(self, input, c, update_emas=False): feat_16 = self.feat_16(feat_8, c) feat_32 = self.feat_32(feat_16, c) feat_64 = self.se_64(feat_4, self.feat_64(feat_32, c)) - feat_128 = self.se_128(feat_8, self.feat_128(feat_64, c)) + feat_128 = self.se_128(feat_8, self.feat_128(feat_64, c)) if self.img_resolution >= 128: feat_last = feat_128 @@ -158,7 +189,7 @@ def __init__( ngf=128, cond=0, mapping_kwargs={}, - synthesis_kwargs={} + synthesis_kwargs={}, ): super().__init__() self.z_dim = z_dim @@ -170,7 +201,13 @@ def __init__( # Mapping and Synthesis Networks self.mapping = DummyMapping() # to fit the StyleGAN API Synthesis = FastganSynthesisCond if cond else FastganSynthesis - self.synthesis = Synthesis(ngf=ngf, z_dim=z_dim, nc=img_channels, img_resolution=img_resolution, **synthesis_kwargs) + self.synthesis = Synthesis( + ngf=ngf, + z_dim=z_dim, + nc=img_channels, + img_resolution=img_resolution, + **synthesis_kwargs + ) def forward(self, z, c, **kwargs): w = self.mapping(z, c)