diff --git a/python/needle/nn/nn_transformer.py b/python/needle/nn/nn_transformer.py index 6cc25ac..dbb67bb 100644 --- a/python/needle/nn/nn_transformer.py +++ b/python/needle/nn/nn_transformer.py @@ -50,22 +50,26 @@ def create_causal_mask(self, i, j, device): def matmul(self, a, b_transpose): """ batched matrix multiplication; - """ - a_shape = (*a.shape[:-1], 1, *a.shape[-1:]) - a = a.reshape(a_shape) + b..., m, k @ b..., k, n -> b..., m, n - b_transpose_shape = (*b_transpose.shape[:-2], 1, *b_transpose.shape[-2:]) - b_transpose = b_transpose.reshape(b_transpose_shape) + b..., m, k, 1 @ + b..., 1, k, n + b..., m, k, n @ + b..., m, k, n + """ + a_shape = (*a.shape, 1) + b_shape = (*b_transpose.shape[:-2], 1, b_transpose.shape[-2], b_transpose.shape[-1]) + a_reshaped = a.reshape(a_shape) + b_reshaped = b_transpose.reshape(b_shape) broadcast_shape = list(a_shape) - broadcast_shape[-2] = b_transpose_shape[-2] - a = a.broadcast_to(broadcast_shape) - - broadcast_shape = list(b_transpose_shape) - broadcast_shape[-3] = a_shape[-3] - b_transpose = b_transpose.broadcast_to(broadcast_shape) - - return (a * b_transpose).sum(len(a.shape) - 1) + broadcast_shape[-1] = b_shape[-1] + a_reshaped = a_reshaped.broadcast_to(broadcast_shape) + b_reshaped = b_reshaped.broadcast_to(broadcast_shape) + out = (a_reshaped * b_reshaped).sum(len(broadcast_shape) - 2) + out_shape = list(a.shape) + out_shape[-1] = b_transpose.shape[-1] + return out.reshape(tuple(out_shape)) def softmax(self, logit): """ @@ -294,4 +298,4 @@ def forward( if not self.batch_first: x = ops.transpose(x, axes=(0, 1)) - return x, init.zeros_like(x) \ No newline at end of file + return x, init.zeros_like(x)