From 03b7ba7f5e489eeff96e6f1777d643d9a491398b Mon Sep 17 00:00:00 2001 From: Panther Date: Sat, 24 Jan 2026 17:35:50 +0530 Subject: [PATCH 1/2] fix(Secretize): selective output wrapping in WrapGeneric using SecretnessAnalysis --- lib/Transforms/Secretize/BUILD | 2 + lib/Transforms/Secretize/WrapGeneric.cpp | 264 ++++++++++++++---- .../Transforms/wrap_generic/wrap_generic.mlir | 17 ++ 3 files changed, 231 insertions(+), 52 deletions(-) diff --git a/lib/Transforms/Secretize/BUILD b/lib/Transforms/Secretize/BUILD index d10bdc4710..2dbffabbd2 100644 --- a/lib/Transforms/Secretize/BUILD +++ b/lib/Transforms/Secretize/BUILD @@ -17,8 +17,10 @@ cc_library( ], deps = [ ":pass_inc_gen", + "@heir//lib/Analysis/SecretnessAnalysis", "@heir//lib/Dialect/Secret/IR:Dialect", "@llvm-project//llvm:Support", + "@llvm-project//mlir:Analysis", "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:IR", "@llvm-project//mlir:Pass", diff --git a/lib/Transforms/Secretize/WrapGeneric.cpp b/lib/Transforms/Secretize/WrapGeneric.cpp index eaa4e7bd54..04d25f2204 100644 --- a/lib/Transforms/Secretize/WrapGeneric.cpp +++ b/lib/Transforms/Secretize/WrapGeneric.cpp @@ -1,25 +1,28 @@ #include +#include "lib/Analysis/SecretnessAnalysis/SecretnessAnalysis.h" #include "lib/Dialect/Secret/IR/SecretDialect.h" #include "lib/Dialect/Secret/IR/SecretOps.h" #include "lib/Dialect/Secret/IR/SecretTypes.h" #include "lib/Transforms/Secretize/Passes.h" -#include "llvm/include/llvm/ADT/STLExtras.h" // from @llvm-project -#include "llvm/include/llvm/ADT/SmallVector.h" // from @llvm-project -#include "mlir/include/mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project -#include "mlir/include/mlir/IR/Block.h" // from @llvm-project -#include "mlir/include/mlir/IR/Builders.h" // from @llvm-project -#include "mlir/include/mlir/IR/BuiltinOps.h" // from @llvm-project -#include "mlir/include/mlir/IR/BuiltinTypes.h" // from @llvm-project -#include "mlir/include/mlir/IR/IRMapping.h" // from @llvm-project -#include "mlir/include/mlir/IR/Location.h" // from @llvm-project -#include "mlir/include/mlir/IR/MLIRContext.h" // from @llvm-project -#include "mlir/include/mlir/IR/PatternMatch.h" // from @llvm-project -#include "mlir/include/mlir/IR/Types.h" // from @llvm-project -#include "mlir/include/mlir/IR/Value.h" // from @llvm-project -#include "mlir/include/mlir/IR/ValueRange.h" // from @llvm-project -#include "mlir/include/mlir/Support/LLVM.h" // from @llvm-project -#include "mlir/include/mlir/Transforms/Passes.h" // from @llvm-project +#include "llvm/include/llvm/ADT/STLExtras.h" // from @llvm-project +#include "llvm/include/llvm/ADT/SmallVector.h" // from @llvm-project +#include "mlir/include/mlir/Analysis/DataFlow/Utils.h" // from @llvm-project +#include "mlir/include/mlir/Analysis/DataFlowFramework.h" // from @llvm-project +#include "mlir/include/mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/include/mlir/IR/Block.h" // from @llvm-project +#include "mlir/include/mlir/IR/Builders.h" // from @llvm-project +#include "mlir/include/mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/include/mlir/IR/BuiltinTypes.h" // from @llvm-project +#include "mlir/include/mlir/IR/IRMapping.h" // from @llvm-project +#include "mlir/include/mlir/IR/Location.h" // from @llvm-project +#include "mlir/include/mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/include/mlir/IR/PatternMatch.h" // from @llvm-project +#include "mlir/include/mlir/IR/Types.h" // from @llvm-project +#include "mlir/include/mlir/IR/Value.h" // from @llvm-project +#include "mlir/include/mlir/IR/ValueRange.h" // from @llvm-project +#include "mlir/include/mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/include/mlir/Transforms/Passes.h" // from @llvm-project #include "mlir/include/mlir/Transforms/WalkPatternRewriteDriver.h" // from @llvm-project namespace mlir { @@ -29,8 +32,8 @@ namespace heir { #include "lib/Transforms/Secretize/Passes.h.inc" struct WrapWithGeneric : public OpRewritePattern { - WrapWithGeneric(mlir::MLIRContext* context) - : mlir::OpRewritePattern(context) {} + WrapWithGeneric(mlir::MLIRContext* context, DataFlowSolver* solver) + : mlir::OpRewritePattern(context), solver(solver) {} LogicalResult matchAndRewrite(func::FuncOp op, PatternRewriter& rewriter) const override { @@ -58,54 +61,203 @@ struct WrapWithGeneric : public OpRewritePattern { return rewriter.notifyMatchFailure(op, "no secret inputs found"); } - auto newOutputs = llvm::to_vector<6>(llvm::map_range( - op.getResultTypes(), - [](Type t) -> Type { return secret::SecretType::get(t); })); + // Externally defined functions have no body - conservatively wrap all + // outputs + if (op.isDeclaration()) { + SmallVector newOutputs; + for (Type resultType : op.getResultTypes()) { + newOutputs.push_back(secret::SecretType::get(resultType)); + } + rewriter.modifyOpInPlace(op, [&] { + op.setFunctionType( + FunctionType::get(getContext(), {newInputs}, {newOutputs})); + }); + return success(); + } + + // Phase 1: Identify which operations depend on secrets + Block& opEntryBlock = op.getRegion().front(); + auto* returnOp = opEntryBlock.getTerminator(); + + // Track which values are secret (including block arguments) + llvm::DenseSet secretValues; + for (unsigned i = 0; i < op.getNumArguments(); i++) { + if (isSecret(op.getArgument(i), solver)) { + secretValues.insert(op.getArgument(i)); + } + } + + // Track which operations are secret-dependent + llvm::DenseSet secretOps; + for (Operation& bodyOp : opEntryBlock) { + if (&bodyOp == returnOp) continue; - // modification to function type should go through the rewriter + // An operation is secret if any of its operands are secret + bool isSecretOp = llvm::any_of(bodyOp.getOperands(), [&](Value operand) { + return secretValues.contains(operand) || isSecret(operand, solver); + }); + + if (isSecretOp) { + secretOps.insert(&bodyOp); + // All results of a secret op become secret + for (Value result : bodyOp.getResults()) { + secretValues.insert(result); + } + } + } + + // Phase 2: Determine output types and which outputs need to be in generic + SmallVector newOutputs; + SmallVector secretReturnValues; + SmallVector plaintextReturnValues; + SmallVector secretReturnIndices; + SmallVector plaintextReturnIndices; + + for (auto [i, resultType] : llvm::enumerate(op.getResultTypes())) { + Value returnVal = returnOp->getOperand(i); + if (secretValues.contains(returnVal) || isSecret(returnVal, solver)) { + newOutputs.push_back(secret::SecretType::get(resultType)); + secretReturnValues.push_back(returnVal); + secretReturnIndices.push_back(i); + } else { + newOutputs.push_back(resultType); + plaintextReturnValues.push_back(returnVal); + plaintextReturnIndices.push_back(i); + } + } + + // Modification to function type should go through the rewriter rewriter.modifyOpInPlace(op, [&] { op.setFunctionType( FunctionType::get(getContext(), {newInputs}, {newOutputs})); }); - // Externally defined functions have no body - if (op.isDeclaration()) { + // If there are no secret-dependent operations AND no secret return values, + // we don't need a generic at all (purely plaintext function). + // But if there are secret return values (e.g., function directly returns + // its secret input), we still need a generic even with no operations. + if (secretOps.empty() && secretReturnValues.empty()) { + // Purely plaintext function - no generic needed return success(); } - // Create a new block where we will insert the new secret.generic and move - // the function ops into. - Block& opEntryBlock = op.getRegion().front(); + + // Phase 3: Collect inputs for the secret.generic block + // These are: (1) secret arguments, (2) plaintext values used by secret ops + SmallVector genericInputs; + SmallVector genericInputTypes; + + // Add all function arguments that are used by secret ops (or are secret) + for (unsigned i = 0; i < op.getNumArguments(); i++) { + genericInputs.push_back(op.getArgument(i)); + genericInputTypes.push_back(op.getArgument(i).getType()); + } + + // Collect plaintext-defined values that are used inside secret ops + SmallVector plaintextValuesUsedInGeneric; + for (Operation* secretOp : secretOps) { + for (Value operand : secretOp->getOperands()) { + // If the operand is from outside the secretOps set (i.e., plaintext) + if (!secretValues.contains(operand)) { + Operation* defOp = operand.getDefiningOp(); + // It's a plaintext value defined by a non-secret op in this function + if (defOp && !secretOps.contains(defOp) && + defOp->getParentRegion() == &op.getRegion()) { + if (!llvm::is_contained(plaintextValuesUsedInGeneric, operand)) { + plaintextValuesUsedInGeneric.push_back(operand); + genericInputs.push_back(operand); + genericInputTypes.push_back(operand.getType()); + } + } + } + } + } + + // Phase 4: Build the secret.generic with only secret ops + SmallVector genericOutputTypes; + for (Value v : secretReturnValues) { + genericOutputTypes.push_back(secret::SecretType::get(v.getType())); + } + + // Create a new block for the rewritten function auto* newBlock = rewriter.createBlock( &opEntryBlock, opEntryBlock.getArgumentTypes(), SmallVector(opEntryBlock.getNumArguments(), op.getLoc())); rewriter.setInsertionPointToStart(newBlock); + + // Build mapping from old block args to new block args + IRMapping outerMapping; + for (unsigned i = 0; i < opEntryBlock.getNumArguments(); ++i) { + outerMapping.map(opEntryBlock.getArgument(i), newBlock->getArgument(i)); + } + + // Clone plaintext operations to the new block (before the generic) + for (Operation& bodyOp : opEntryBlock) { + if (&bodyOp == returnOp) continue; + if (!secretOps.contains(&bodyOp)) { + Operation* clonedOp = rewriter.clone(bodyOp, outerMapping); + for (unsigned i = 0; i < bodyOp.getNumResults(); ++i) { + outerMapping.map(bodyOp.getResult(i), clonedOp->getResult(i)); + } + } + } + + // Update genericInputs to use the new block's values + SmallVector mappedGenericInputs; + for (Value v : genericInputs) { + mappedGenericInputs.push_back(outerMapping.lookupOrDefault(v)); + } + + // Now create the secret.generic auto newGeneric = secret::GenericOp::create( - rewriter, op.getLoc(), op.getArguments(), newOutputs, + rewriter, op.getLoc(), mappedGenericInputs, genericOutputTypes, [&](OpBuilder& b, Location loc, ValueRange blockArguments) { - // Map the input values to the block arguments. - IRMapping mp; - for (unsigned i = 0; i < blockArguments.size(); ++i) { - mp.map(opEntryBlock.getArgument(i), blockArguments[i]); + // Map inputs to block arguments + IRMapping innerMapping; + for (unsigned i = 0; i < genericInputs.size(); ++i) { + innerMapping.map(genericInputs[i], blockArguments[i]); + } + + // Clone only secret operations into the generic + for (Operation& bodyOp : opEntryBlock) { + if (&bodyOp == returnOp) continue; + if (secretOps.contains(&bodyOp)) { + Operation* clonedOp = b.clone(bodyOp, innerMapping); + for (unsigned i = 0; i < bodyOp.getNumResults(); ++i) { + innerMapping.map(bodyOp.getResult(i), clonedOp->getResult(i)); + } + } } - auto* returnOp = opEntryBlock.getTerminator(); - secret::YieldOp::create(b, loc, - llvm::to_vector(llvm::map_range( - returnOp->getOperands(), [&](Value v) { - return mp.lookupOrDefault(v); - }))); - returnOp->erase(); + // Yield only the secret return values + SmallVector yieldValues; + for (Value v : secretReturnValues) { + yieldValues.push_back(innerMapping.lookupOrDefault(v)); + } + secret::YieldOp::create(b, loc, yieldValues); }); - Block& genericBlock = newGeneric.getRegion().front(); - rewriter.inlineBlockBefore(&opEntryBlock, - &genericBlock.getOperations().back(), - genericBlock.getArguments()); - func::ReturnOp::create(rewriter, op.getLoc(), newGeneric.getResults()); + // Build the final return values in the correct order + SmallVector finalReturnValues(op.getNumResults()); + unsigned secretResultIdx = 0; + for (unsigned idx : secretReturnIndices) { + finalReturnValues[idx] = newGeneric.getResult(secretResultIdx++); + } + for (unsigned idx : plaintextReturnIndices) { + Value returnVal = returnOp->getOperand(idx); + finalReturnValues[idx] = outerMapping.lookupOrDefault(returnVal); + } + + func::ReturnOp::create(rewriter, op.getLoc(), finalReturnValues); + + // Erase the old block + rewriter.eraseBlock(&opEntryBlock); return success(); } + + private: + DataFlowSolver* solver; }; struct ConvertFuncCall : public OpRewritePattern { @@ -159,20 +311,28 @@ struct WrapGeneric : impl::WrapGenericBase { using WrapGenericBase::WrapGenericBase; void detectSecretGeneric() { - bool hasSecretGeneric = false; - getOperation().walk([&](secret::GenericOp op) { hasSecretGeneric = true; }); - if (!hasSecretGeneric) { - getOperation().emitWarning( - "No secret found in the module. Did you forget to annotate " - "{secret.secret} to the function arguments?"); - } + // Note: Since we now correctly handle functions that return only + // plaintext values (which don't get a secret.generic), we should not + // warn about missing secret.generic ops. The warning was intended + // for the case where users forgot to annotate secret inputs, but that + // is already caught by the hasSecrets check in WrapWithGeneric. } void runOnOperation() override { MLIRContext* context = &getContext(); + // Run SecretnessAnalysis to determine which values depend on secrets + DataFlowSolver solver; + dataflow::loadBaselineAnalyses(solver); + solver.load(); + if (failed(solver.initializeAndRun(getOperation()))) { + getOperation()->emitOpError() << "Failed to run SecretnessAnalysis.\n"; + signalPassFailure(); + return; + } + mlir::RewritePatternSet patterns(context); - patterns.add(context); + patterns.add(context, &solver); (void)walkAndApplyPatterns(getOperation(), std::move(patterns)); // func.call should be converted after callee func type updated diff --git a/tests/Dialect/Secret/Transforms/wrap_generic/wrap_generic.mlir b/tests/Dialect/Secret/Transforms/wrap_generic/wrap_generic.mlir index 41fed0d74a..faf89b36bc 100644 --- a/tests/Dialect/Secret/Transforms/wrap_generic/wrap_generic.mlir +++ b/tests/Dialect/Secret/Transforms/wrap_generic/wrap_generic.mlir @@ -84,3 +84,20 @@ module { return %alloc : memref<1x80xi8> } } + +// ----- + +// Regression test for issue #2553: plaintext constant should not become secret +// When a function only returns values that don't depend on secrets, +// no secret.generic should be created. +module { + // CHECK: @plaintext_output(%arg0: !secret.secret) -> i8 + func.func @plaintext_output(%x: i32 {secret.secret}) -> i8 { + // The constant does not depend on the secret input + // CHECK-NOT: secret.generic + %0 = arith.constant 42 : i8 + // CHECK: return %{{.*}} : i8 + func.return %0 : i8 + } +} + From 3f34785cfe82116995dcf681cc2e7228b57ce332 Mon Sep 17 00:00:00 2001 From: Panther Date: Wed, 28 Jan 2026 10:15:10 +0530 Subject: [PATCH 2/2] refactor: simplify WrapGeneric to use selective output types only Per maintainer feedback, this change: - Keeps the original 'wrap entire block' approach - Only changes output type selection based on SecretnessAnalysis - Removes the complex op partitioning (cloning only secret ops) - Fixes #2553 by not creating generic when no outputs depend on secrets The key insight from maintainers is that hoisting plaintext ops should be handled by HoistPlaintextOps pass, not here. --- lib/Transforms/Secretize/WrapGeneric.cpp | 195 +++++------------------ 1 file changed, 39 insertions(+), 156 deletions(-) diff --git a/lib/Transforms/Secretize/WrapGeneric.cpp b/lib/Transforms/Secretize/WrapGeneric.cpp index 04d25f2204..d77e9d6146 100644 --- a/lib/Transforms/Secretize/WrapGeneric.cpp +++ b/lib/Transforms/Secretize/WrapGeneric.cpp @@ -10,7 +10,6 @@ #include "mlir/include/mlir/Analysis/DataFlow/Utils.h" // from @llvm-project #include "mlir/include/mlir/Analysis/DataFlowFramework.h" // from @llvm-project #include "mlir/include/mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project -#include "mlir/include/mlir/IR/Block.h" // from @llvm-project #include "mlir/include/mlir/IR/Builders.h" // from @llvm-project #include "mlir/include/mlir/IR/BuiltinOps.h" // from @llvm-project #include "mlir/include/mlir/IR/BuiltinTypes.h" // from @llvm-project @@ -62,12 +61,11 @@ struct WrapWithGeneric : public OpRewritePattern { } // Externally defined functions have no body - conservatively wrap all - // outputs + // outputs as secret if (op.isDeclaration()) { - SmallVector newOutputs; - for (Type resultType : op.getResultTypes()) { - newOutputs.push_back(secret::SecretType::get(resultType)); - } + auto newOutputs = llvm::to_vector<6>(llvm::map_range( + op.getResultTypes(), + [](Type t) -> Type { return secret::SecretType::get(t); })); rewriter.modifyOpInPlace(op, [&] { op.setFunctionType( FunctionType::get(getContext(), {newInputs}, {newOutputs})); @@ -75,54 +73,21 @@ struct WrapWithGeneric : public OpRewritePattern { return success(); } - // Phase 1: Identify which operations depend on secrets + // Use SecretnessAnalysis to determine which outputs depend on secrets Block& opEntryBlock = op.getRegion().front(); auto* returnOp = opEntryBlock.getTerminator(); - // Track which values are secret (including block arguments) - llvm::DenseSet secretValues; - for (unsigned i = 0; i < op.getNumArguments(); i++) { - if (isSecret(op.getArgument(i), solver)) { - secretValues.insert(op.getArgument(i)); - } - } - - // Track which operations are secret-dependent - llvm::DenseSet secretOps; - for (Operation& bodyOp : opEntryBlock) { - if (&bodyOp == returnOp) continue; - - // An operation is secret if any of its operands are secret - bool isSecretOp = llvm::any_of(bodyOp.getOperands(), [&](Value operand) { - return secretValues.contains(operand) || isSecret(operand, solver); - }); - - if (isSecretOp) { - secretOps.insert(&bodyOp); - // All results of a secret op become secret - for (Value result : bodyOp.getResults()) { - secretValues.insert(result); - } - } - } - - // Phase 2: Determine output types and which outputs need to be in generic + // Determine output types: only wrap in secret if the value depends on + // secrets SmallVector newOutputs; - SmallVector secretReturnValues; - SmallVector plaintextReturnValues; - SmallVector secretReturnIndices; - SmallVector plaintextReturnIndices; - + bool hasSecretOutputs = false; for (auto [i, resultType] : llvm::enumerate(op.getResultTypes())) { Value returnVal = returnOp->getOperand(i); - if (secretValues.contains(returnVal) || isSecret(returnVal, solver)) { + if (isSecret(returnVal, solver)) { newOutputs.push_back(secret::SecretType::get(resultType)); - secretReturnValues.push_back(returnVal); - secretReturnIndices.push_back(i); + hasSecretOutputs = true; } else { newOutputs.push_back(resultType); - plaintextReturnValues.push_back(returnVal); - plaintextReturnIndices.push_back(i); } } @@ -132,126 +97,43 @@ struct WrapWithGeneric : public OpRewritePattern { FunctionType::get(getContext(), {newInputs}, {newOutputs})); }); - // If there are no secret-dependent operations AND no secret return values, - // we don't need a generic at all (purely plaintext function). - // But if there are secret return values (e.g., function directly returns - // its secret input), we still need a generic even with no operations. - if (secretOps.empty() && secretReturnValues.empty()) { - // Purely plaintext function - no generic needed + // If no outputs depend on secrets, don't create a generic block. + // This fixes issue #2553: functions that return only plaintext values + // should not have their outputs wrapped in secret types. + if (!hasSecretOutputs) { return success(); } - // Phase 3: Collect inputs for the secret.generic block - // These are: (1) secret arguments, (2) plaintext values used by secret ops - SmallVector genericInputs; - SmallVector genericInputTypes; - - // Add all function arguments that are used by secret ops (or are secret) - for (unsigned i = 0; i < op.getNumArguments(); i++) { - genericInputs.push_back(op.getArgument(i)); - genericInputTypes.push_back(op.getArgument(i).getType()); - } - - // Collect plaintext-defined values that are used inside secret ops - SmallVector plaintextValuesUsedInGeneric; - for (Operation* secretOp : secretOps) { - for (Value operand : secretOp->getOperands()) { - // If the operand is from outside the secretOps set (i.e., plaintext) - if (!secretValues.contains(operand)) { - Operation* defOp = operand.getDefiningOp(); - // It's a plaintext value defined by a non-secret op in this function - if (defOp && !secretOps.contains(defOp) && - defOp->getParentRegion() == &op.getRegion()) { - if (!llvm::is_contained(plaintextValuesUsedInGeneric, operand)) { - plaintextValuesUsedInGeneric.push_back(operand); - genericInputs.push_back(operand); - genericInputTypes.push_back(operand.getType()); - } - } - } - } - } - - // Phase 4: Build the secret.generic with only secret ops - SmallVector genericOutputTypes; - for (Value v : secretReturnValues) { - genericOutputTypes.push_back(secret::SecretType::get(v.getType())); - } - - // Create a new block for the rewritten function + // Create a new block where we will insert the new secret.generic and move + // the function ops into. auto* newBlock = rewriter.createBlock( &opEntryBlock, opEntryBlock.getArgumentTypes(), SmallVector(opEntryBlock.getNumArguments(), op.getLoc())); rewriter.setInsertionPointToStart(newBlock); - - // Build mapping from old block args to new block args - IRMapping outerMapping; - for (unsigned i = 0; i < opEntryBlock.getNumArguments(); ++i) { - outerMapping.map(opEntryBlock.getArgument(i), newBlock->getArgument(i)); - } - - // Clone plaintext operations to the new block (before the generic) - for (Operation& bodyOp : opEntryBlock) { - if (&bodyOp == returnOp) continue; - if (!secretOps.contains(&bodyOp)) { - Operation* clonedOp = rewriter.clone(bodyOp, outerMapping); - for (unsigned i = 0; i < bodyOp.getNumResults(); ++i) { - outerMapping.map(bodyOp.getResult(i), clonedOp->getResult(i)); - } - } - } - - // Update genericInputs to use the new block's values - SmallVector mappedGenericInputs; - for (Value v : genericInputs) { - mappedGenericInputs.push_back(outerMapping.lookupOrDefault(v)); - } - - // Now create the secret.generic auto newGeneric = secret::GenericOp::create( - rewriter, op.getLoc(), mappedGenericInputs, genericOutputTypes, + rewriter, op.getLoc(), op.getArguments(), newOutputs, [&](OpBuilder& b, Location loc, ValueRange blockArguments) { - // Map inputs to block arguments - IRMapping innerMapping; - for (unsigned i = 0; i < genericInputs.size(); ++i) { - innerMapping.map(genericInputs[i], blockArguments[i]); + // Map the input values to the block arguments. + IRMapping mp; + for (unsigned i = 0; i < blockArguments.size(); ++i) { + mp.map(opEntryBlock.getArgument(i), blockArguments[i]); } - // Clone only secret operations into the generic - for (Operation& bodyOp : opEntryBlock) { - if (&bodyOp == returnOp) continue; - if (secretOps.contains(&bodyOp)) { - Operation* clonedOp = b.clone(bodyOp, innerMapping); - for (unsigned i = 0; i < bodyOp.getNumResults(); ++i) { - innerMapping.map(bodyOp.getResult(i), clonedOp->getResult(i)); - } - } - } - - // Yield only the secret return values - SmallVector yieldValues; - for (Value v : secretReturnValues) { - yieldValues.push_back(innerMapping.lookupOrDefault(v)); - } - secret::YieldOp::create(b, loc, yieldValues); + // Yield the return values, mapped through the IR mapping + secret::YieldOp::create(b, loc, + llvm::to_vector(llvm::map_range( + returnOp->getOperands(), [&](Value v) { + return mp.lookupOrDefault(v); + }))); + returnOp->erase(); }); - // Build the final return values in the correct order - SmallVector finalReturnValues(op.getNumResults()); - unsigned secretResultIdx = 0; - for (unsigned idx : secretReturnIndices) { - finalReturnValues[idx] = newGeneric.getResult(secretResultIdx++); - } - for (unsigned idx : plaintextReturnIndices) { - Value returnVal = returnOp->getOperand(idx); - finalReturnValues[idx] = outerMapping.lookupOrDefault(returnVal); - } - - func::ReturnOp::create(rewriter, op.getLoc(), finalReturnValues); - - // Erase the old block - rewriter.eraseBlock(&opEntryBlock); + Block& genericBlock = newGeneric.getRegion().front(); + rewriter.inlineBlockBefore(&opEntryBlock, + &genericBlock.getOperations().back(), + genericBlock.getArguments()); + func::ReturnOp::create(rewriter, op.getLoc(), newGeneric.getResults()); return success(); } @@ -311,11 +193,12 @@ struct WrapGeneric : impl::WrapGenericBase { using WrapGenericBase::WrapGenericBase; void detectSecretGeneric() { - // Note: Since we now correctly handle functions that return only - // plaintext values (which don't get a secret.generic), we should not - // warn about missing secret.generic ops. The warning was intended - // for the case where users forgot to annotate secret inputs, but that - // is already caught by the hasSecrets check in WrapWithGeneric. + bool hasSecretGeneric = false; + getOperation().walk([&](secret::GenericOp op) { hasSecretGeneric = true; }); + // Note: We no longer warn if no secret.generic is found, because + // functions that return only plaintext values intentionally don't + // create a secret.generic block. The hasSecrets check in WrapWithGeneric + // already catches the case where users forget to annotate secret inputs. } void runOnOperation() override {