
Didn't set model to eval mode when validating in chapter 8 #104

Description

@tianlianghai

In the book and the code, the first ResNet we meet has some batch norm layers.
But the validate function doesn't call model.eval(), and I don't think this is correct.


As a result, in my experiment, when I call model.eval() in the validation function, the resulting accuracy is lower and the loss is higher.

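import torch

# `device` is assumed to be defined as in the book's code; set here so the snippet runs on its own.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def val(model, val_loader, loss_fn):
    model.eval()  # batch norm layers now use their running statistics
    losses = []
    correct = 0
    total = 0
    with torch.no_grad():  # gradients are not needed for validation
        for x, y in val_loader:
            x, y = x.to(device), y.to(device)
            out = model(x)
            loss = loss_fn(out, y)
            losses.append(loss.item())
            _, pred = torch.max(out, dim=1)  # index of the highest-scoring class
            correct += int(torch.sum(pred == y))
            total += y.shape[0]
    loss = sum(losses) / len(losses)  # mean of the per-batch losses
    acc = correct / total
    print(f"val loss:{loss:.3f}, acc:{acc:.3f}")
    model.train()  # switch back to training mode for the next epoch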

If I add the model.eval() line to the code, the validation accuracy even goes down during training and is almost always 50%.

When I don't add the model.eval() line, the test accuracy actually reaches about 85%.
Shouldn't we call model.eval() when validating?
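
For context, here is a minimal sketch of what model.eval() changes for a batch norm layer. It uses a standalone nn.BatchNorm1d on synthetic data, not the book's ResNet, so the layer, the input tensor, and the variable names are all illustrative assumptions. In train mode the layer normalizes with the current batch's statistics. In eval mode it uses its running estimates, which after a single update (default momentum 0.1) are still far from the data's real mean and variance. This shows that the two modes can give very different outputs; it does not by itself establish the cause of the accuracy gap reported above.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Fresh layer: running_mean starts at 0 and running_var at 1.
bn = nn.BatchNorm1d(3)

# Synthetic batch (assumption for illustration): mean around 10, std around 5.
x = torch.randn(64, 3) * 5.0 + 10.0

# Train mode: normalize with this batch's statistics and update the running estimates.
bn.train()
out_train = bn(x)

# torch.no_grad() only disables gradient tracking; it does not change the layer's mode.
with torch.no_grad():
    mode_inside_no_grad = bn.training  # still True here

# Eval mode: normalize with the running estimates instead of batch statistics.
bn.eval()
with torch.no_grad():
    out_eval = bn(x)

train_mean = out_train.mean().item()
eval_mean = out_eval.mean().item()

print(f"train-mode output mean: {train_mean:.4f}")
print(f"eval-mode output mean:  {eval_mean:.4f}")
print(f"running_mean after one update: {bn.running_mean.tolist()}")
```

The train-mode output is centered near zero, while the eval-mode output on the same input is shifted well away from zero because the running estimates have only had one update. Whether something similar happens in the chapter 8 model depends on how many batches it has seen and on the batch size, so this is only a starting point for investigating the question.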
