-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmain.py
More file actions
100 lines (84 loc) · 4.08 KB
/
main.py
File metadata and controls
100 lines (84 loc) · 4.08 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
import os
import time
import random
import numpy as np
import scipy, multiprocessing
from tensorflow import keras
import tensorflow as tf
import tensorlayer as tl
from tensorlayer.layers import (Input, Conv2d, BatchNorm2d, Elementwise, SubpixelConv2d, Flatten, Dense)
from tensorlayer.models import Model
from PIL import Image
def get_G(input_shape):
    """Build the SRGAN generator as a TensorLayer static ``Model``.

    Layers are created in this exact order:
      * a 3x3, 64-filter conv with ReLU (shallow feature map);
      * 16 residual blocks, each conv -> BN(ReLU) -> conv -> BN plus an
        identity skip from the block input;
      * conv -> BN followed by a long skip back to the shallow feature map;
      * two upsampling stages, each a 256-filter 3x3 conv and a 2x subpixel
        shuffle with ReLU (4x spatial upscaling overall);
      * a 1x1 conv to 3 channels with tanh.

    Args:
        input_shape: Shape handed to ``tensorlayer.layers.Input``
            (presumably ``[batch, height, width, 3]`` — confirm with caller).

    Returns:
        A ``tensorlayer.models.Model`` named ``"generator"``.
    """
    conv_init = tf.random_normal_initializer(stddev=0.02)
    bn_init = tf.random_normal_initializer(1., 0.02)

    def residual_block(block_in):
        # Two bias-free convs with BN; ReLU only after the first BN.
        y = Conv2d(64, (3, 3), (1, 1), padding='SAME', W_init=conv_init, b_init=None)(block_in)
        y = BatchNorm2d(act=tf.nn.relu, gamma_init=bn_init)(y)
        y = Conv2d(64, (3, 3), (1, 1), padding='SAME', W_init=conv_init, b_init=None)(y)
        y = BatchNorm2d(gamma_init=bn_init)(y)
        return Elementwise(tf.add)([block_in, y])

    def upsample_2x(x):
        # 256 = 64 * 2 * 2 channels, which the 2x subpixel shuffle folds
        # back into 64 channels at double the height and width.
        x = Conv2d(256, (3, 3), (1, 1), padding='SAME', W_init=conv_init)(x)
        return SubpixelConv2d(scale=2, n_out_channels=None, act=tf.nn.relu)(x)

    inputs = Input(input_shape)
    shallow = Conv2d(64, (3, 3), (1, 1), act=tf.nn.relu, padding='SAME', W_init=conv_init)(inputs)

    x = shallow
    for _ in range(16):
        x = residual_block(x)

    # Close the residual trunk and add the long skip connection.
    x = Conv2d(64, (3, 3), (1, 1), padding='SAME', W_init=conv_init, b_init=None)(x)
    x = BatchNorm2d(gamma_init=bn_init)(x)
    x = Elementwise(tf.add)([x, shallow])

    for _ in range(2):
        x = upsample_2x(x)

    outputs = Conv2d(3, (1, 1), (1, 1), act=tf.nn.tanh, padding='SAME', W_init=conv_init)(x)
    return Model(inputs=inputs, outputs=outputs, name="generator")
def get_D(input_shape):
    """Build the SRGAN discriminator as a TensorLayer static ``Model``.

    Layers are created in this exact order:
      * a 4x4 stride-2 conv (64 filters, leaky ReLU 0.2, with bias);
      * five 4x4 stride-2 conv+BN stages with 128, 256, 512, 1024 and 2048
        filters, each followed by leaky ReLU;
      * 1x1 conv+BN to 1024 (leaky ReLU), then 1x1 conv+BN to 512 (no
        activation) — this 512-channel map is kept as a skip tensor;
      * a bottleneck of 1x1 -> 3x3 -> 3x3 conv+BN (128, 128, 512 filters),
        added back to the skip tensor and passed through leaky ReLU;
      * Flatten and a single-unit Dense layer with no activation argument
        (presumably a real/fake logit — confirm against the training loss).

    Args:
        input_shape: Shape handed to ``tensorlayer.layers.Input``
            (presumably ``[batch, height, width, 3]`` — confirm with caller).

    Returns:
        A ``tensorlayer.models.Model`` named ``"discriminator"``.
    """
    kernel_init = tf.random_normal_initializer(stddev=0.02)
    bn_gamma_init = tf.random_normal_initializer(1., 0.02)
    base = 64
    leaky_relu = lambda x: tl.act.lrelu(x, 0.2)

    def conv_bn(x, n_filter, kernel, stride, act=None):
        # Bias-free conv followed by BN; BN carries the (optional) activation.
        x = Conv2d(n_filter, kernel, stride, padding='SAME', W_init=kernel_init, b_init=None)(x)
        if act is None:
            return BatchNorm2d(gamma_init=bn_gamma_init)(x)
        return BatchNorm2d(act=act, gamma_init=bn_gamma_init)(x)

    nin = Input(input_shape)
    x = Conv2d(base, (4, 4), (2, 2), act=leaky_relu, padding='SAME', W_init=kernel_init)(nin)

    # Strided downsampling stages: 128 -> 256 -> 512 -> 1024 -> 2048 filters.
    for mult in (2, 4, 8, 16, 32):
        x = conv_bn(x, base * mult, (4, 4), (2, 2), leaky_relu)

    x = conv_bn(x, base * 16, (1, 1), (1, 1), leaky_relu)
    skip = conv_bn(x, base * 8, (1, 1), (1, 1))

    # Residual bottleneck on top of the 512-channel skip tensor.
    x = conv_bn(skip, base * 2, (1, 1), (1, 1), leaky_relu)
    x = conv_bn(x, base * 2, (3, 3), (1, 1), leaky_relu)
    x = conv_bn(x, base * 8, (3, 3), (1, 1))
    x = Elementwise(combine_fn=tf.add, act=leaky_relu)([x, skip])

    x = Flatten()(x)
    score = Dense(n_units=1, W_init=kernel_init)(x)
    return Model(inputs=nin, outputs=score, name="discriminator")
def cropImages(hr_path=r"./samples/valid_gen.png",
               lr_path=r"./samples/valid_lr.png",
               hr_out="./samples/valid_gen_crop.png",
               lr_out="./samples/valid_lr_crop.png",
               box=(650, 120, 1150, 500),
               scale=4):
    """Crop corresponding regions from the generated HR image and its LR input.

    The high-resolution image is cropped to ``box``. The low-resolution image
    is cropped to the same region scaled down by ``scale``, so the LR crop is
    exactly ``1/scale`` the size of the HR crop. The source images are not
    modified; the crops are written to new files.

    Args:
        hr_path: Path of the high-resolution (generated) image.
        lr_path: Path of the low-resolution input image.
        hr_out: Output path for the HR crop.
        lr_out: Output path for the LR crop.
        box: ``(left, top, right, bottom)`` crop box in HR pixel coordinates.
        scale: Integer HR/LR resolution ratio (default 4).

    Raises:
        ValueError: If ``scale`` is less than 1.
    """
    if scale < 1:
        raise ValueError(f"scale must be a positive integer, got {scale!r}")

    left, top, right, bottom = box
    # Derive the LR box from the HR box's origin and size using integer math.
    # Scaling each edge with "/" gives fractional coordinates that Pillow
    # rounds independently (650/4 -> 162, 1150/4 -> 288), which produced a
    # 126-px-wide LR crop for a 500-px-wide HR crop.
    lr_left, lr_top = left // scale, top // scale
    lr_box = (
        lr_left,
        lr_top,
        lr_left + (right - left) // scale,
        lr_top + (bottom - top) // scale,
    )

    # Context managers close the underlying files; Image.open keeps them open
    # otherwise.
    with Image.open(hr_path) as im_hr:
        im_hr.crop((left, top, right, bottom)).save(hr_out)
    with Image.open(lr_path) as im_lr:
        im_lr.crop(lr_box).save(lr_out)