From be46bd2daca4bb685a1780c09bcce4996b60c97a Mon Sep 17 00:00:00 2001 From: heng Date: Sat, 20 Apr 2019 20:30:01 +0200 Subject: [PATCH] minor changes for pytorch_v1.0 --- train.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/train.py b/train.py index ba44aab..a56e842 100644 --- a/train.py +++ b/train.py @@ -229,9 +229,9 @@ def train(epoch, best_val_loss): loss.backward() optimizer.step() - mse_train.append(F.mse_loss(output, target).data[0]) - nll_train.append(loss_nll.data[0]) - kl_train.append(loss_kl.data[0]) + mse_train.append(F.mse_loss(output, target).item()) + nll_train.append(loss_nll.item()) + kl_train.append(loss_kl.item()) nll_val = [] acc_val = [] @@ -260,9 +260,9 @@ def train(epoch, best_val_loss): acc = edge_accuracy(logits, relations) acc_val.append(acc) - mse_val.append(F.mse_loss(output, target).data[0]) - nll_val.append(loss_nll.data[0]) - kl_val.append(loss_kl.data[0]) + mse_val.append(F.mse_loss(output, target).item()) + nll_val.append(loss_nll.item()) + kl_val.append(loss_kl.item()) print('Epoch: {:04d}'.format(epoch), 'nll_train: {:.10f}'.format(np.mean(nll_train)), @@ -329,9 +329,9 @@ def test(): acc = edge_accuracy(logits, relations) acc_test.append(acc) - mse_test.append(F.mse_loss(output, target).data[0]) - nll_test.append(loss_nll.data[0]) - kl_test.append(loss_kl.data[0]) + mse_test.append(F.mse_loss(output, target).item()) + nll_test.append(loss_nll.item()) + kl_test.append(loss_kl.item()) # For plotting purposes if args.decoder == 'rnn':