I want to transfer the MT framework to an NLP task, but I don't understand how it is trained with unlabeled data. I understand the idea of the paper, but I'm confused about the implementation.
if isinstance(model_out, Variable):
    assert args.logit_distance_cost < 0
    logit1 = model_out
    ema_logit = ema_model_out
else:
    assert len(model_out) == 2
    assert len(ema_model_out) == 2
    logit1, logit2 = model_out
    ema_logit, _ = ema_model_out

ema_logit = Variable(ema_logit.detach().data, requires_grad=False)

if args.logit_distance_cost >= 0:
    class_logit, cons_logit = logit1, logit2
    res_loss = args.logit_distance_cost * residual_logit_criterion(class_logit, cons_logit) / minibatch_size
    meters.update('res_loss', res_loss.data[0])
else:
    class_logit, cons_logit = logit1, logit1
    res_loss = 0

class_loss = class_criterion(class_logit, target_var) / minibatch_size
meters.update('class_loss', class_loss.data[0])

ema_class_loss = class_criterion(ema_logit, target_var) / minibatch_size
meters.update('ema_class_loss', ema_class_loss.data[0])

if args.consistency:
    consistency_weight = get_current_consistency_weight(epoch)
    meters.update('cons_weight', consistency_weight)
    consistency_loss = consistency_weight * consistency_criterion(cons_logit, ema_logit) / minibatch_size
    meters.update('cons_loss', consistency_loss.data[0])
else:
    consistency_loss = 0
    meters.update('cons_loss', 0)
I notice that TwoStreamBatchSampler divides the dataset into a labeled part and an unlabeled part, but the code above seems to handle labeled and unlabeled data in the same way. I think only the labeled part of model_out should be used to calculate class_loss. Did I get it wrong?
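To make the question concrete: one way a single code path could still restrict the classification loss to labeled examples is to give every unlabeled example a sentinel target and let the criterion skip it. Below is a minimal sketch of that idea, not the repository's code. The sentinel NO_LABEL = -1, the helper names supervised_loss and consistency_loss, and the squared-error consistency term are my own assumptions for illustration; whether the repository does this would need to be checked where class_criterion and minibatch_size are defined.

```python
import torch
import torch.nn as nn

NO_LABEL = -1  # assumed sentinel target marking an unlabeled example


def supervised_loss(logits, targets):
    """Cross-entropy summed over labeled rows only, divided by the labeled count."""
    # ignore_index makes rows whose target equals NO_LABEL contribute zero loss.
    criterion = nn.CrossEntropyLoss(reduction="sum", ignore_index=NO_LABEL)
    labeled_count = int((targets != NO_LABEL).sum())
    if labeled_count == 0:
        return logits.sum() * 0.0  # no labeled rows: zero loss, graph kept intact
    return criterion(logits, targets) / labeled_count


def consistency_loss(student_logits, teacher_logits):
    """Squared error between softmax outputs, averaged over ALL rows (hypothetical form)."""
    student_prob = torch.softmax(student_logits, dim=1)
    teacher_prob = torch.softmax(teacher_logits.detach(), dim=1)  # no gradient to teacher
    return ((student_prob - teacher_prob) ** 2).sum(dim=1).mean()


torch.manual_seed(0)
logits = torch.randn(4, 3, requires_grad=True)          # student outputs, 4 examples
teacher_logits = torch.randn(4, 3)                      # stand-in for EMA model outputs
targets = torch.tensor([2, 0, NO_LABEL, NO_LABEL])      # last two rows are unlabeled

class_loss = supervised_loss(logits, targets)
# Reference value: plain cross-entropy on the two labeled rows only.
labeled_only = nn.functional.cross_entropy(logits[:2], targets[:2])
cons_loss = consistency_loss(logits, teacher_logits)
```

In this sketch the whole batch goes through the same calls, yet class_loss equals the loss computed on the labeled rows alone, while cons_loss uses every row. If the repository works the same way, that would explain why the quoted code does not slice model_out.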