
How to train the model with unlabeled data? #59

@broken-dream


I want to transfer the Mean Teacher (MT) framework to an NLP task, but I don't understand how to train it with unlabeled data. I understand the idea of the paper, but I'm confused about the implementation.

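    # Excerpt from the training loop (older, pre-0.4 PyTorch API: Variable, .data[0]).
    # model_out / ema_model_out are the outputs of the student and the EMA (teacher) model.
    if isinstance(model_out, Variable):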
        assert args.logit_distance_cost < 0
        logit1 = model_out
        ema_logit = ema_model_out
    else:
        assert len(model_out) == 2
        assert len(ema_model_out) == 2
        logit1, logit2 = model_out
        ema_logit, _ = ema_model_out

    # Treat the teacher's logits as fixed targets: no gradient flows back into the EMA model.
    ema_logit = Variable(ema_logit.detach().data, requires_grad=False)

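    # Two output heads: class_logit is fit to the labels, cons_logit to the teacher; res_loss keeps them close.
    if args.logit_distance_cost >= 0: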
        class_logit, cons_logit = logit1, logit2
        res_loss = args.logit_distance_cost * residual_logit_criterion(class_logit, cons_logit) / minibatch_size
        meters.update('res_loss', res_loss.data[0])
    else:
        class_logit, cons_logit = logit1, logit1
        res_loss = 0

    # Supervised losses for the student (class_loss) and the teacher (ema_class_loss).
    class_loss = class_criterion(class_logit, target_var) / minibatch_size
    meters.update('class_loss', class_loss.data[0])

    ema_class_loss = class_criterion(ema_logit, target_var) / minibatch_size
    meters.update('ema_class_loss', ema_class_loss.data[0])

    # Consistency loss between student and teacher predictions; its weight depends on the epoch.
    if args.consistency:
        consistency_weight = get_current_consistency_weight(epoch)
        meters.update('cons_weight', consistency_weight)
        consistency_loss = consistency_weight * consistency_criterion(cons_logit, ema_logit) / minibatch_size
        meters.update('cons_loss', consistency_loss.data[0])
    else:
        consistency_loss = 0
        meters.update('cons_loss', 0)

I notice that TwoStreamBatchSampler splits the dataset into a labeled part and an unlabeled part, but the code above seems to handle labeled and unlabeled data in the same way. I think only the labeled part of model_out should be used to calculate class_loss. Am I getting this wrong?
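A minimal, self-contained sketch of how a single, uniform loss computation can still treat labeled and unlabeled samples differently (plain Python, no PyTorch). It assumes unlabeled samples carry a sentinel target, here a hypothetical `NO_LABEL = -1`, and that each batch mixes a fixed number of unlabeled and labeled indices in the spirit of TwoStreamBatchSampler. The names `two_stream_batches`, `masked_class_loss` and `consistency_loss` are illustrative, not the repository's API. Under that assumption the class loss simply skips sentinel targets, so only labeled samples contribute, while the consistency loss uses every sample. In PyTorch the same masking is available through the `ignore_index` argument of `nn.CrossEntropyLoss`; whether `class_criterion` and the dataset targets in this repository are set up that way is worth checking in the training script.

```python
import math
import random

NO_LABEL = -1  # hypothetical sentinel target marking unlabeled samples


def _cycle_shuffled(items, rng):
    """Yield items forever, reshuffling after each full pass."""
    items = list(items)
    if not items:
        raise ValueError("need at least one labeled index")
    while True:
        rng.shuffle(items)
        yield from items


def two_stream_batches(unlabeled_idx, labeled_idx, batch_size, labeled_batch_size, seed=0):
    """Yield index batches that mix unlabeled and labeled samples.

    One pass over the shuffled unlabeled indices defines an epoch; labeled
    indices are cycled so every batch holds exactly `labeled_batch_size` of them.
    Unlabeled indices that do not fill a whole batch are dropped.
    """
    unlabeled_bs = batch_size - labeled_batch_size
    if unlabeled_bs <= 0:
        raise ValueError("batch_size must exceed labeled_batch_size")
    rng = random.Random(seed)
    unlabeled = list(unlabeled_idx)
    rng.shuffle(unlabeled)
    labeled_stream = _cycle_shuffled(labeled_idx, rng)
    for start in range(0, len(unlabeled) - unlabeled_bs + 1, unlabeled_bs):
        batch = unlabeled[start:start + unlabeled_bs]
        batch += [next(labeled_stream) for _ in range(labeled_batch_size)]
        yield batch


def log_softmax(logits):
    """Numerically stable log-softmax of one sample's logits."""
    m = max(logits)
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - log_z for x in logits]


def masked_class_loss(batch_logits, targets):
    """Summed cross-entropy over labeled samples, divided by the FULL batch size."""
    total = 0.0
    for logits, target in zip(batch_logits, targets):
        if target == NO_LABEL:
            # Unlabeled: no supervised signal (a raw -1 index would wrongly pick the last class).
            continue
        total -= log_softmax(logits)[target]
    return total / len(targets)


def consistency_loss(student_logits, teacher_logits):
    """Squared error between softmax outputs, summed over classes, averaged over ALL samples.

    Teacher outputs are plain numbers here, i.e. fixed targets with no gradient.
    """
    total = 0.0
    for s, t in zip(student_logits, teacher_logits):
        p_s = [math.exp(v) for v in log_softmax(s)]
        p_t = [math.exp(v) for v in log_softmax(t)]
        total += sum((a - b) ** 2 for a, b in zip(p_s, p_t))
    return total / len(student_logits)


if __name__ == "__main__":
    for batch in two_stream_batches(range(8), [8, 9], batch_size=4, labeled_batch_size=1):
        print(batch)  # three unlabeled indices followed by one labeled index
    logits = [[2.0, 0.5], [0.1, 1.2], [1.0, 1.0], [0.3, 0.2]]
    targets = [NO_LABEL, NO_LABEL, NO_LABEL, 0]
    print(masked_class_loss(logits, targets))
```

In this sketch the class loss is divided by the full batch size rather than by the number of labeled samples, mirroring the `/ minibatch_size` in the snippet above; averaging over labeled samples only would also be reasonable but changes the effective scale of the supervised term relative to the consistency term.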
