From 2cc1e55f5bf1949a88d11ae987da4e02b16b6618 Mon Sep 17 00:00:00 2001 From: Seher Ellis Date: Fri, 12 Jun 2026 11:58:40 -0700 Subject: [PATCH] Add option to remove singleton opt-barriers only. PiperOrigin-RevId: 931262255 --- xla/hlo/transforms/expanders/BUILD | 20 +++- .../optimization_barrier_expander.cc | 16 ++- .../expanders/optimization_barrier_expander.h | 11 +- .../optimization_barrier_expander_test.cc | 100 ++++++++++++++++++ 4 files changed, 141 insertions(+), 6 deletions(-) create mode 100644 xla/hlo/transforms/expanders/optimization_barrier_expander_test.cc diff --git a/xla/hlo/transforms/expanders/BUILD b/xla/hlo/transforms/expanders/BUILD index 6be143110e06b..f7bdfae5bacb6 100644 --- a/xla/hlo/transforms/expanders/BUILD +++ b/xla/hlo/transforms/expanders/BUILD @@ -40,7 +40,8 @@ cc_library( srcs = ["optimization_barrier_expander.cc"], hdrs = ["optimization_barrier_expander.h"], deps = [ - ":op_expander_pass", + "//xla/hlo/ir:hlo", + "//xla/hlo/pass:hlo_pass", "//xla/tsl/platform:status_macros", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_set", @@ -49,6 +50,23 @@ cc_library( ], ) +xla_cc_test( + name = "optimization_barrier_expander_test", + srcs = ["optimization_barrier_expander_test.cc"], + deps = [ + ":optimization_barrier_expander", + "//xla:shape_util", + "//xla/hlo/ir:hlo", + "//xla/hlo/testlib:hlo_hardware_independent_test_base", + "//xla/hlo/testlib:test", + "//xla/hlo/testlib:verified_hlo_module", + "@com_google_absl//absl/log", + "@com_google_absl//absl/types:span", + "@com_google_googletest//:gtest", + "@com_google_googletest//:gtest_main", + ], +) + cc_library( name = "comparison_expander", srcs = ["comparison_expander.cc"], diff --git a/xla/hlo/transforms/expanders/optimization_barrier_expander.cc b/xla/hlo/transforms/expanders/optimization_barrier_expander.cc index 689a981f9f38d..bd290015bad8b 100644 --- a/xla/hlo/transforms/expanders/optimization_barrier_expander.cc +++ b/xla/hlo/transforms/expanders/optimization_barrier_expander.cc @@ -23,6 +23,10 @@ limitations under the License. #include "absl/status/statusor.h" #include "absl/strings/string_view.h" #include "xla/tsl/platform/status_macros.h" +#include "xla/hlo/ir/hlo_computation.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_module.h" +#include "xla/hlo/ir/hlo_opcode.h" namespace xla { @@ -34,10 +38,16 @@ absl::StatusOr OptimizationBarrierExpander::RunImpl( module->MakeNonfusionComputations(execution_threads)) { bool modified = false; for (HloInstruction* inst : computation->instructions()) { - if (inst->opcode() == HloOpcode::kOptimizationBarrier) { - barriers.push_back(inst); - modified = true; + if (inst->opcode() != HloOpcode::kOptimizationBarrier) { + continue; } + if (only_remove_singleton_opt_barriers_ && inst->operand_count() == 1 && + inst->operand(0)->opcode() == HloOpcode::kTuple && + inst->operand(0)->operand_count() > 1) { + continue; + } + barriers.push_back(inst); + modified = true; } if (modified && module->has_schedule()) { diff --git a/xla/hlo/transforms/expanders/optimization_barrier_expander.h b/xla/hlo/transforms/expanders/optimization_barrier_expander.h index 8592815c24bb6..06878b5264a08 100644 --- a/xla/hlo/transforms/expanders/optimization_barrier_expander.h +++ b/xla/hlo/transforms/expanders/optimization_barrier_expander.h @@ -19,7 +19,8 @@ limitations under the License. #include "absl/container/flat_hash_set.h" #include "absl/status/statusor.h" #include "absl/strings/string_view.h" -#include "xla/hlo/transforms/expanders/op_expander_pass.h" +#include "xla/hlo/ir/hlo_module.h" +#include "xla/hlo/pass/hlo_pass_interface.h" namespace xla { @@ -27,13 +28,19 @@ namespace xla { class OptimizationBarrierExpander : public HloModulePass { public: OptimizationBarrierExpander() = default; + explicit OptimizationBarrierExpander(bool only_remove_singleton_opt_barriers) + : only_remove_singleton_opt_barriers_( + only_remove_singleton_opt_barriers) {} - absl::string_view name() const override { return "cse_barrier_expander"; } + absl::string_view name() const override { return "opt-barrier-expander"; } protected: absl::StatusOr RunImpl( HloModule* module, const absl::flat_hash_set& execution_threads) override; + + private: + bool only_remove_singleton_opt_barriers_ = false; }; } // namespace xla diff --git a/xla/hlo/transforms/expanders/optimization_barrier_expander_test.cc b/xla/hlo/transforms/expanders/optimization_barrier_expander_test.cc new file mode 100644 index 0000000000000..753afa84427c7 --- /dev/null +++ b/xla/hlo/transforms/expanders/optimization_barrier_expander_test.cc @@ -0,0 +1,100 @@ +/* Copyright 2026 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/hlo/transforms/expanders/optimization_barrier_expander.h" + +#include +#include + +#include +#include "absl/log/log.h" +#include "absl/types/span.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_opcode.h" +#include "xla/hlo/testlib/hlo_hardware_independent_test_base.h" +#include "xla/hlo/testlib/test.h" +#include "xla/hlo/testlib/verified_hlo_module.h" +#include "xla/shape.h" + +namespace xla { +namespace { + +class HloBarrierInstruction : public HloInstruction { + public: + HloBarrierInstruction(const Shape& shape, + absl::Span operands) + : HloInstruction(HloOpcode::kOptimizationBarrier, shape) { + for (HloInstruction* operand : operands) { + AppendOperand(operand); + } + } +}; + +class OptimizationBarrierExpanderTest : public HloHardwareIndependentTestBase { +}; + +TEST_F(OptimizationBarrierExpanderTest, RemovesOptimizationBarrier) { + const char* hlo = R"( +HloModule module + +ENTRY main { + param0 = f32[10] parameter(0) + add0 = f32[10] add(param0, param0) + add1 = f32[10] add(param0, add0) + tuple = (f32[10], f32[10]) tuple(add0, add1) + barrier = (f32[10], f32[10]) opt-barrier(tuple) + gte = f32[10] get-tuple-element(barrier), index=0 + ROOT root = f32[10] add(gte, param0) +} +)"; + + ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(hlo)); + + OptimizationBarrierExpander expander; + ASSERT_OK_AND_ASSIGN(bool changed, expander.Run(module.get())); + EXPECT_TRUE(changed); + EXPECT_TRUE( + FindInstructions(module.get(), HloOpcode::kOptimizationBarrier).empty()); + VLOG(1) << module->ToString(); +} + +TEST_F(OptimizationBarrierExpanderTest, RemovesOnlySingularOptBarrier) { + const char* hlo = R"( +HloModule module + +ENTRY main { + param0 = f32[10] parameter(0) + param1 = f32[10] parameter(1) + add0 = f32[10] add(param0, param1) + barrier = f32[10] opt-barrier(add0) + ROOT add1 = f32[10] add(barrier, param0) +} +)"; + + ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(hlo)); + + OptimizationBarrierExpander expander( + /*only_remove_singleton_opt_barriers=*/true); + ASSERT_OK_AND_ASSIGN(bool changed, expander.Run(module.get())); + EXPECT_TRUE(changed); + EXPECT_TRUE( + FindInstructions(module.get(), HloOpcode::kOptimizationBarrier).empty()); + VLOG(1) << module->ToString(); +} + +} // namespace +} // namespace xla