diff --git a/python/needle/backend_ndarray/ndarray.py b/python/needle/backend_ndarray/ndarray.py index 48eb7e1..f66c164 100755 --- a/python/needle/backend_ndarray/ndarray.py +++ b/python/needle/backend_ndarray/ndarray.py @@ -508,7 +508,7 @@ def __matmul__(self, other): def tile(a, tile): return a.as_strided( (a.shape[0] // tile, a.shape[1] // tile, tile, tile), - (a.shape[1] * tile, tile, self.shape[1], 1), + (a.shape[1] * tile, tile, a.shape[1], 1), ) t = self.device.__tile_size__