From a33f239dac83e9313c55dfef9b872769db52bd08 Mon Sep 17 00:00:00 2001 From: Jeremy Kun Date: Thu, 25 Jun 2026 07:22:33 -0700 Subject: [PATCH] Avoid fusing ops with only secret operands in linalg-fuse-linear-ops Fixes #3122 PiperOrigin-RevId: 937969405 --- lib/Transforms/LinalgFuseLinearOps/BUILD | 2 + .../LinalgFuseLinearOps.cpp | 75 ++++++++++++++----- .../linalg_fuse_linear_ops.mlir | 42 +++++++++++ 3 files changed, 100 insertions(+), 19 deletions(-) diff --git a/lib/Transforms/LinalgFuseLinearOps/BUILD b/lib/Transforms/LinalgFuseLinearOps/BUILD index ab5b1836ca..a463c43500 100644 --- a/lib/Transforms/LinalgFuseLinearOps/BUILD +++ b/lib/Transforms/LinalgFuseLinearOps/BUILD @@ -14,7 +14,9 @@ cc_library( ], deps = [ ":pass_inc_gen", + "@heir//lib/Analysis/SecretnessAnalysis", "@llvm-project//llvm:Support", + "@llvm-project//mlir:Analysis", "@llvm-project//mlir:ArithDialect", "@llvm-project//mlir:DestinationStyleOpInterface", "@llvm-project//mlir:DialectUtils", diff --git a/lib/Transforms/LinalgFuseLinearOps/LinalgFuseLinearOps.cpp b/lib/Transforms/LinalgFuseLinearOps/LinalgFuseLinearOps.cpp index a3c85f979b..dada5da921 100644 --- a/lib/Transforms/LinalgFuseLinearOps/LinalgFuseLinearOps.cpp +++ b/lib/Transforms/LinalgFuseLinearOps/LinalgFuseLinearOps.cpp @@ -5,10 +5,13 @@ #include #include -#include "llvm/include/llvm/ADT/TypeSwitch.h" // from @llvm-project -#include "mlir/include/mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project -#include "mlir/include/mlir/Dialect/Linalg/IR/Linalg.h" // from @llvm-project -#include "mlir/include/mlir/Dialect/Tensor/IR/Tensor.h" // from @llvm-project +#include "lib/Analysis/SecretnessAnalysis/SecretnessAnalysis.h" +#include "llvm/include/llvm/ADT/TypeSwitch.h" // from @llvm-project +#include "mlir/include/mlir/Analysis/DataFlow/Utils.h" // from @llvm-project +#include "mlir/include/mlir/Analysis/DataFlowFramework.h" // from @llvm-project +#include "mlir/include/mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project +#include "mlir/include/mlir/Dialect/Linalg/IR/Linalg.h" // from @llvm-project +#include "mlir/include/mlir/Dialect/Tensor/IR/Tensor.h" // from @llvm-project #include "mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h" // from @llvm-project #include "mlir/include/mlir/IR/AffineExpr.h" // from @llvm-project #include "mlir/include/mlir/IR/AffineMap.h" // from @llvm-project @@ -58,10 +61,15 @@ LogicalResult findLinearOpAndOperand(OpTy op, Operation*& linearOp, } template -LogicalResult fuseScaleOrDivIntoLinearOp(PatternRewriter& rewriter, OpTy op) { +LogicalResult fuseScaleOrDivIntoLinearOp(PatternRewriter& rewriter, OpTy op, + DataFlowSolver& solver) { Operation* linearOp = nullptr; - Value scale_val; - if (failed(findLinearOpAndOperand(op, linearOp, scale_val))) return failure(); + Value scaleVal; + if (failed(findLinearOpAndOperand(op, linearOp, scaleVal))) return failure(); + + if (isSecret(scaleVal, &solver)) { + return failure(); + } Value weights; int64_t weightOperandIdx = -1; @@ -97,7 +105,7 @@ LogicalResult fuseScaleOrDivIntoLinearOp(PatternRewriter& rewriter, OpTy op) { if (!weights) return failure(); auto weightsType = cast(weights.getType()); - auto scaleValType = cast(scale_val.getType()); + auto scaleValType = cast(scaleVal.getType()); if (scaleValType.getRank() != 1) return failure(); @@ -125,7 +133,7 @@ LogicalResult fuseScaleOrDivIntoLinearOp(PatternRewriter& rewriter, OpTy op) { weightsType.getElementType()); auto broadcastOp = linalg::BroadcastOp::create( - rewriter, linearOp->getLoc(), scale_val, emptyOp.getResult(), addedDims); + rewriter, linearOp->getLoc(), scaleVal, emptyOp.getResult(), addedDims); auto scaledWeights = OpTy::create(rewriter, op.getLoc(), weights, broadcastOp.getResults()[0]); @@ -139,11 +147,16 @@ LogicalResult fuseScaleOrDivIntoLinearOp(PatternRewriter& rewriter, OpTy op) { } template -LogicalResult fuseAddOrSubIntoLinearOp(PatternRewriter& rewriter, OpTy op) { +LogicalResult fuseAddOrSubIntoLinearOp(PatternRewriter& rewriter, OpTy op, + DataFlowSolver& solver) { Operation* linearOp = nullptr; Value addend; if (failed(findLinearOpAndOperand(op, linearOp, addend))) return failure(); + if (isSecret(addend, &solver)) { + return failure(); + } + auto destStyleOp = dyn_cast(linearOp); if (!destStyleOp) return failure(); @@ -204,35 +217,51 @@ LogicalResult fuseAddOrSubIntoLinearOp(PatternRewriter& rewriter, OpTy op) { } struct FuseScaleIntoLinearOp : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; + FuseScaleIntoLinearOp(MLIRContext* context, DataFlowSolver& solver) + : OpRewritePattern(context), solver(solver) {} LogicalResult matchAndRewrite(arith::MulFOp op, PatternRewriter& rewriter) const override { - return fuseScaleOrDivIntoLinearOp(rewriter, op); + return fuseScaleOrDivIntoLinearOp(rewriter, op, solver); } + + private: + DataFlowSolver& solver; }; struct FuseDivIntoLinearOp : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; + FuseDivIntoLinearOp(MLIRContext* context, DataFlowSolver& solver) + : OpRewritePattern(context), solver(solver) {} LogicalResult matchAndRewrite(arith::DivFOp op, PatternRewriter& rewriter) const override { - return fuseScaleOrDivIntoLinearOp(rewriter, op); + return fuseScaleOrDivIntoLinearOp(rewriter, op, solver); } + + private: + DataFlowSolver& solver; }; struct FuseAddIntoLinearOp : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; + FuseAddIntoLinearOp(MLIRContext* context, DataFlowSolver& solver) + : OpRewritePattern(context), solver(solver) {} LogicalResult matchAndRewrite(arith::AddFOp op, PatternRewriter& rewriter) const override { - return fuseAddOrSubIntoLinearOp(rewriter, op); + return fuseAddOrSubIntoLinearOp(rewriter, op, solver); } + + private: + DataFlowSolver& solver; }; struct FuseSubIntoLinearOp : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; + FuseSubIntoLinearOp(MLIRContext* context, DataFlowSolver& solver) + : OpRewritePattern(context), solver(solver) {} LogicalResult matchAndRewrite(arith::SubFOp op, PatternRewriter& rewriter) const override { - return fuseAddOrSubIntoLinearOp(rewriter, op); + return fuseAddOrSubIntoLinearOp(rewriter, op, solver); } + + private: + DataFlowSolver& solver; }; } // namespace @@ -243,9 +272,17 @@ struct LinalgFuseLinearOps MLIRContext* context = &getContext(); auto module = getOperation(); + DataFlowSolver solver; + dataflow::loadBaselineAnalyses(solver); + solver.load(); + if (failed(solver.initializeAndRun(module))) { + module->emitOpError() << "Failed to run SecretnessAnalysis.\n"; + return signalPassFailure(); + } + RewritePatternSet patterns(context); patterns.add(context); + FuseAddIntoLinearOp, FuseSubIntoLinearOp>(context, solver); if (failed(applyPatternsGreedily(module, std::move(patterns)))) { return signalPassFailure(); diff --git a/tests/Transforms/linalg_fuse_linear_ops/linalg_fuse_linear_ops.mlir b/tests/Transforms/linalg_fuse_linear_ops/linalg_fuse_linear_ops.mlir index ccf90ee165..7e136c0d16 100644 --- a/tests/Transforms/linalg_fuse_linear_ops/linalg_fuse_linear_ops.mlir +++ b/tests/Transforms/linalg_fuse_linear_ops/linalg_fuse_linear_ops.mlir @@ -99,3 +99,45 @@ func.func @fuse_matmul_with_bias(%arg0: tensor<2x3xf32>, %arg1: tensor<3x4xf32>, %4 = arith.addf %2, %broadcasted_0 : tensor<2x4xf32> return %4 : tensor<2x4xf32> } + +// ----- + +// CHECK: func.func @no_fuse_secret_scale +// CHECK: %[[VAL_0:.*]] = tensor.empty() : tensor<2xf32> +// CHECK: %[[VAL_1:.*]] = linalg.matvec ins(%arg0, %arg1 : tensor<2x3xf32>, tensor<3xf32>) outs(%[[VAL_0]] : tensor<2xf32>) +// CHECK: %[[VAL_2:.*]] = arith.mulf %[[VAL_1]], %arg2 : tensor<2xf32> +// CHECK: return %[[VAL_2]] +func.func @no_fuse_secret_scale(%arg0: tensor<2x3xf32>, %arg1: tensor<3xf32>, %arg2: tensor<2xf32> {secret.secret}) -> tensor<2xf32> { + %0 = tensor.empty() : tensor<2xf32> + %1 = linalg.matvec ins(%arg0, %arg1 : tensor<2x3xf32>, tensor<3xf32>) outs(%0 : tensor<2xf32>) -> tensor<2xf32> + %2 = arith.mulf %1, %arg2 : tensor<2xf32> + return %2 : tensor<2xf32> +} + +// ----- + +// CHECK: func.func @no_fuse_secret_square +// CHECK: %[[VAL_0:.*]] = tensor.empty() : tensor<2xf32> +// CHECK: %[[VAL_1:.*]] = linalg.matvec ins(%arg0, %arg1 : tensor<2x3xf32>, tensor<3xf32>) outs(%[[VAL_0]] : tensor<2xf32>) +// CHECK: %[[VAL_2:.*]] = arith.mulf %[[VAL_1]], %[[VAL_1]] : tensor<2xf32> +// CHECK: return %[[VAL_2]] +func.func @no_fuse_secret_square(%arg0: tensor<2x3xf32>, %arg1: tensor<3xf32> {secret.secret}) -> tensor<2xf32> { + %0 = tensor.empty() : tensor<2xf32> + %1 = linalg.matvec ins(%arg0, %arg1 : tensor<2x3xf32>, tensor<3xf32>) outs(%0 : tensor<2xf32>) -> tensor<2xf32> + %2 = arith.mulf %1, %1 : tensor<2xf32> + return %2 : tensor<2xf32> +} + +// ----- + +// CHECK: func.func @no_fuse_secret_bias +// CHECK: %[[VAL_0:.*]] = tensor.empty() : tensor<2xf32> +// CHECK: %[[VAL_1:.*]] = linalg.matvec ins(%arg0, %arg1 : tensor<2x3xf32>, tensor<3xf32>) outs(%[[VAL_0]] : tensor<2xf32>) +// CHECK: %[[VAL_2:.*]] = arith.addf %[[VAL_1]], %arg2 : tensor<2xf32> +// CHECK: return %[[VAL_2]] +func.func @no_fuse_secret_bias(%arg0: tensor<2x3xf32>, %arg1: tensor<3xf32>, %arg2: tensor<2xf32> {secret.secret}) -> tensor<2xf32> { + %0 = tensor.empty() : tensor<2xf32> + %1 = linalg.matvec ins(%arg0, %arg1 : tensor<2x3xf32>, tensor<3xf32>) outs(%0 : tensor<2xf32>) -> tensor<2xf32> + %2 = arith.addf %1, %arg2 : tensor<2xf32> + return %2 : tensor<2xf32> +}