Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
83 changes: 60 additions & 23 deletions pg_modules/networks_fastgan.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,15 @@
# modified by Axel Sauer for "Projected GANs Converge Faster"
#
import torch.nn as nn
from pg_modules.blocks import (InitLayer, UpBlockBig, UpBlockBigCond, UpBlockSmall, UpBlockSmallCond, SEBlock, conv2d)
from pg_modules.blocks import (
InitLayer,
UpBlockBig,
UpBlockBigCond,
UpBlockSmall,
UpBlockSmallCond,
SEBlock,
conv2d,
)


def normalize_second_moment(x, dim=1, eps=1e-8):
Expand All @@ -25,25 +33,35 @@ def __init__(self, ngf=128, z_dim=256, nc=3, img_resolution=256, lite=False):
self.z_dim = z_dim

# channel multiplier
nfc_multi = {2: 16, 4:16, 8:8, 16:4, 32:2, 64:2, 128:1, 256:0.5,
512:0.25, 1024:0.125}
nfc_multi = {
2: 16,
4: 16,
8: 8,
16: 4,
32: 2,
64: 2,
128: 1,
256: 0.5,
512: 0.25,
1024: 0.125,
}
nfc = {}
for k, v in nfc_multi.items():
nfc[k] = int(v*ngf)
nfc[k] = int(v * ngf)

# layers
self.init = InitLayer(z_dim, channel=nfc[2], sz=4)

UpBlock = UpBlockSmall if lite else UpBlockBig

self.feat_8 = UpBlock(nfc[4], nfc[8])
self.feat_16 = UpBlock(nfc[8], nfc[16])
self.feat_32 = UpBlock(nfc[16], nfc[32])
self.feat_64 = UpBlock(nfc[32], nfc[64])
self.feat_8 = UpBlock(nfc[4], nfc[8])
self.feat_16 = UpBlock(nfc[8], nfc[16])
self.feat_32 = UpBlock(nfc[16], nfc[32])
self.feat_64 = UpBlock(nfc[32], nfc[64])
self.feat_128 = UpBlock(nfc[64], nfc[128])
self.feat_256 = UpBlock(nfc[128], nfc[256])

self.se_64 = SEBlock(nfc[4], nfc[64])
self.se_64 = SEBlock(nfc[4], nfc[64])
self.se_128 = SEBlock(nfc[8], nfc[128])
self.se_256 = SEBlock(nfc[16], nfc[256])

Expand All @@ -64,7 +82,7 @@ def forward(self, input, c, **kwargs):
feat_16 = self.feat_16(feat_8)
feat_32 = self.feat_32(feat_16)
feat_64 = self.se_64(feat_4, self.feat_64(feat_32))
feat_128 = self.se_128(feat_8, self.feat_128(feat_64))
feat_128 = self.se_128(feat_8, self.feat_128(feat_64))

if self.img_resolution >= 128:
feat_last = feat_128
Expand All @@ -82,26 +100,39 @@ def forward(self, input, c, **kwargs):


class FastganSynthesisCond(nn.Module):
def __init__(self, ngf=64, z_dim=256, nc=3, img_resolution=256, num_classes=1000, lite=False):
def __init__(
self, ngf=64, z_dim=256, nc=3, img_resolution=256, num_classes=1000, lite=False
):
super().__init__()

self.z_dim = z_dim
nfc_multi = {2: 16, 4:16, 8:8, 16:4, 32:2, 64:2, 128:1, 256:0.5,
512:0.25, 1024:0.125, 2048:0.125}
nfc_multi = {
2: 16,
4: 16,
8: 8,
16: 4,
32: 2,
64: 2,
128: 1,
256: 0.5,
512: 0.25,
1024: 0.125,
2048: 0.125,
}
nfc = {}
for k, v in nfc_multi.items():
nfc[k] = int(v*ngf)
nfc[k] = int(v * ngf)

self.img_resolution = img_resolution

self.init = InitLayer(z_dim, channel=nfc[2], sz=4)

UpBlock = UpBlockSmallCond if lite else UpBlockBigCond

self.feat_8 = UpBlock(nfc[4], nfc[8], z_dim)
self.feat_16 = UpBlock(nfc[8], nfc[16], z_dim)
self.feat_32 = UpBlock(nfc[16], nfc[32], z_dim)
self.feat_64 = UpBlock(nfc[32], nfc[64], z_dim)
self.feat_8 = UpBlock(nfc[4], nfc[8], z_dim)
self.feat_16 = UpBlock(nfc[8], nfc[16], z_dim)
self.feat_32 = UpBlock(nfc[16], nfc[32], z_dim)
self.feat_64 = UpBlock(nfc[32], nfc[64], z_dim)
self.feat_128 = UpBlock(nfc[64], nfc[128], z_dim)
self.feat_256 = UpBlock(nfc[128], nfc[256], z_dim)

Expand All @@ -112,10 +143,10 @@ def __init__(self, ngf=64, z_dim=256, nc=3, img_resolution=256, num_classes=1000
self.to_big = conv2d(nfc[img_resolution], nc, 3, 1, 1, bias=True)

if img_resolution > 256:
self.feat_512 = UpBlock(nfc[256], nfc[512])
self.feat_512 = UpBlock(nfc[256], nfc[512], z_dim)
self.se_512 = SEBlock(nfc[32], nfc[512])
if img_resolution > 512:
self.feat_1024 = UpBlock(nfc[512], nfc[1024])
self.feat_1024 = UpBlock(nfc[512], nfc[1024], z_dim)

self.embed = nn.Embedding(num_classes, z_dim)

Expand All @@ -130,7 +161,7 @@ def forward(self, input, c, update_emas=False):
feat_16 = self.feat_16(feat_8, c)
feat_32 = self.feat_32(feat_16, c)
feat_64 = self.se_64(feat_4, self.feat_64(feat_32, c))
feat_128 = self.se_128(feat_8, self.feat_128(feat_64, c))
feat_128 = self.se_128(feat_8, self.feat_128(feat_64, c))

if self.img_resolution >= 128:
feat_last = feat_128
Expand Down Expand Up @@ -158,7 +189,7 @@ def __init__(
ngf=128,
cond=0,
mapping_kwargs={},
synthesis_kwargs={}
synthesis_kwargs={},
):
super().__init__()
self.z_dim = z_dim
Expand All @@ -170,7 +201,13 @@ def __init__(
# Mapping and Synthesis Networks
self.mapping = DummyMapping() # to fit the StyleGAN API
Synthesis = FastganSynthesisCond if cond else FastganSynthesis
self.synthesis = Synthesis(ngf=ngf, z_dim=z_dim, nc=img_channels, img_resolution=img_resolution, **synthesis_kwargs)
self.synthesis = Synthesis(
ngf=ngf,
z_dim=z_dim,
nc=img_channels,
img_resolution=img_resolution,
**synthesis_kwargs
)

def forward(self, z, c, **kwargs):
w = self.mapping(z, c)
Expand Down