-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy path: fromScratch_example_ch6_overfit_weight_decay.py
More file actions
73 lines (54 loc) · 1.95 KB
/
Copy path: fromScratch_example_ch6_overfit_weight_decay.py
File metadata and controls
73 lines (54 loc) · 1.95 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
# -*- coding: utf-8 -*-
"""
Created on Fri May 28 18:56:58 2021
@author: 이창현
"""
import os
import sys
sys.path.append(os.pardir) # 부모 디렉터리의 파일을 가져올 수 있도록 설정
import numpy as np
import matplotlib.pyplot as plt
from dataset.mnist import load_mnist
from common.multi_layer_net import MultiLayerNet
from common.optimizer import SGD
# Load the MNIST train/test splits via the project's loader (normalize=True).
(x_train, t_train), (x_test, t_test) = load_mnist(normalize=True)

# Keep only the first 300 training samples. A deliberately tiny training set
# makes the network overfit, which is the effect this experiment studies.
x_train, t_train = x_train[:300], t_train[:300]
# Weight-decay coefficient handed to the network (presumably the strength of
# an L2 penalty inside MultiLayerNet -- confirm in common/multi_layer_net.py).
weight_decay_lambda = 0.1

# Six hidden layers of 100 units each. The 784 inputs presumably correspond to
# flattened 28x28 MNIST images, and the 10 outputs to the digit classes.
network = MultiLayerNet(
    input_size=784,
    hidden_size_list=[100] * 6,
    output_size=10,
    weight_decay_lambda=weight_decay_lambda,
)

# SGD optimizer with learning rate 0.01.
optimizer = SGD(lr=0.01)
# Train with random mini-batches and record train/test accuracy once per
# "epoch" until max_epochs evaluations have been collected.
max_epochs = 201
train_size = x_train.shape[0]
batch_size = 100

train_acc_list = []
test_acc_list = []

# Number of mini-batch updates per epoch. Integer division keeps the
# `i % iter_per_epoch == 0` test meaningful when train_size is not an exact
# multiple of batch_size: a float period such as 300 / 128 = 2.34 would only
# ever match at i == 0, so epoch_cnt would never advance.
iter_per_epoch = max(train_size // batch_size, 1)
epoch_cnt = 0

i = 0
while epoch_cnt < max_epochs:
    # Sample a mini-batch (np.random.choice samples with replacement by default).
    batch_mask = np.random.choice(train_size, batch_size)
    x_batch = x_train[batch_mask]
    t_batch = t_train[batch_mask]

    # One optimization step on the sampled batch.
    grads = network.gradient(x_batch, t_batch)
    optimizer.update(network.params, grads)

    # At each epoch boundary, evaluate on the full training and test sets.
    if i % iter_per_epoch == 0:
        train_acc = network.accuracy(x_train, t_train)
        test_acc = network.accuracy(x_test, t_test)
        train_acc_list.append(train_acc)
        test_acc_list.append(test_acc)
        print(f"epoch:{epoch_cnt}, train acc: {train_acc}, test acc: {test_acc}")
        epoch_cnt += 1

    i += 1
# Plot the per-epoch train/test accuracy curves recorded by the training loop.
markers = {'train': 'o', 'test': 's'}
# Derive the x-axis from the recorded history rather than max_epochs, so the
# plot stays valid even if the number of evaluations differs from max_epochs.
x = np.arange(len(train_acc_list))
plt.plot(x, train_acc_list, marker=markers['train'], label='train', markevery=10)
plt.plot(x, test_acc_list, marker=markers['test'], label='test', markevery=10)
plt.xlabel("epochs")
plt.ylabel("accuracy")
plt.ylim(0, 1.0)
plt.legend(loc='lower right')
plt.show()