[QUESTION] about EquivariantLayerNormV2 #13

@kzhoa

Hi, thanks for your wonderful work.
I have a question about the class EquivariantLayerNormV2 in /nets/layer_norm.py.
The field mean is computed with
field_mean = torch.mean(field, dim=1, keepdim=True) # [batch, mul, 1]
Should dim here actually be -1?
I ask because field_norm is computed with dim=-1 a few lines later.

Related code:

for mul, ir in self.irreps:  # mul is the multiplicity (number of copies) of some irrep type (ir)
    d = ir.dim
    field = node_input.narrow(1, ix, mul*d)
    ix += mul * d

    # [batch * sample, mul, repr]
    field = field.reshape(-1, mul, d)

    # For scalars first compute and subtract the mean
    if ir.l == 0 and ir.p == 1:
        # TODO: here the dim should be -1?
        field_mean = torch.mean(field, dim=1, keepdim=True) # [batch, mul, 1]
        field = field - field_mean

    # Then compute the rescaling factor (norm of each feature vector)
    # Rescaling of the norms themselves based on the option "normalization"
    if self.normalization == 'norm':
        field_norm = field.pow(2).sum(-1)  # [batch * sample, mul]
    elif self.normalization == 'component':
        field_norm = field.pow(2).mean(-1)  # [batch * sample, mul]
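
To make the two options concrete, here is a minimal sketch of what each choice of dim does to a scalar field. It assumes the reshaped field has shape [batch, mul, d] with d = 1 for scalar irreps, as the reshape above suggests; the sizes (4 nodes, 8 channels) and variable names are made up for illustration and are not taken from the repository.

```python
import torch

# Hypothetical sizes for illustration only: 4 nodes, 8 copies of a scalar irrep (d = 1).
num_nodes, mul, d = 4, 8, 1
torch.manual_seed(0)
field = torch.randn(num_nodes, mul, d)  # [batch, mul, repr]

# dim=1: average over the mul channels of each node, as in the quoted code.
mean_over_mul = field.mean(dim=1, keepdim=True)    # [batch, 1, repr]
centered_mul = field - mean_over_mul

# dim=-1: average over the repr axis, which has length 1 for scalars.
mean_over_repr = field.mean(dim=-1, keepdim=True)  # [batch, mul, 1]
centered_repr = field - mean_over_repr

print(mean_over_mul.shape)   # torch.Size([4, 1, 1])
print(mean_over_repr.shape)  # torch.Size([4, 8, 1])

# With dim=1 each node's channels are centered around zero.
print(torch.allclose(centered_mul.mean(dim=1), torch.zeros(num_nodes, d), atol=1e-6))  # True

# With dim=-1 every scalar is its own mean, so the whole field becomes zero.
print(torch.equal(centered_repr, torch.zeros_like(field)))  # True
```

Under these assumptions, dim=-1 would subtract each scalar from itself, whereas dim=1 centers the scalars across channels. The shapes also suggest that the [batch, mul, 1] comment describes the dim=-1 result rather than the dim=1 result.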
