Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 15 additions & 7 deletions xla/service/gpu/llvm_gpu_backend/nvptx_backend.cc
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@ limitations under the License.
#include <functional>
#include <memory>
#include <string>
#include <variant>
#include <vector>

#include "absl/base/call_once.h"
Expand All @@ -30,7 +29,6 @@ limitations under the License.
#include "absl/strings/str_cat.h"
#include "absl/strings/str_format.h"
#include "xla/tsl/platform/status_macros.h"
#include "third_party/gpus/cuda/include/cuda.h"
#include "llvm/ADT/FloatingPointMode.h"
#include "llvm/Analysis/CGSCCPassManager.h"
#include "llvm/Analysis/LazyCallGraph.h"
Expand Down Expand Up @@ -59,6 +57,7 @@ limitations under the License.
#include "llvm/Transforms/IPO/AlwaysInliner.h"
#include "llvm/Transforms/IPO/Internalize.h"
#include "llvm/Transforms/Scalar.h"
#include "third_party/gpus/cuda/include/cuda.h"
#include "xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.h"
#include "xla/service/gpu/llvm_gpu_backend/load_ir_module.h"
#include "xla/service/gpu/llvm_gpu_backend/nvptx_libdevice_path.h"
Expand Down Expand Up @@ -239,7 +238,13 @@ std::vector<std::string> GetNVPTXBackendOptions(
return backend_llvm_opts;
}

std::string GetSmName(se::CudaComputeCapability compute_capability) {
constexpr se::CudaComputeCapability kSupportedVersions[] = {
{12, 1}, {12, 0}, {11, 0}, {10, 3}, {10, 0}, {9, 0}, {8, 9}, {8, 7},
{8, 6}, {8, 0}, {7, 5}, {7, 2}, {7, 0}, {6, 2}, {6, 1}, {6, 0},
{5, 3}, {5, 2}, {5, 0}, {3, 7}, {3, 5}, {3, 2}, {3, 0}};

se::CudaComputeCapability ResolveSupportedComputeCapability(
se::CudaComputeCapability compute_capability) {
using CudaComputeCapabilities =
se::CudaComputeCapability::CudaComputeCapabilities;

Expand All @@ -248,10 +253,6 @@ std::string GetSmName(se::CudaComputeCapability compute_capability) {
se::CudaComputeCapability::FeatureExtension::kNone;
// If the current compute capability isn't known, fallback to the
// most recent version before it.
constexpr stream_executor::CudaComputeCapability kSupportedVersions[] = {
{12, 1}, {12, 0}, {11, 0}, {10, 3}, {10, 0}, {9, 0}, {8, 9}, {8, 7},
{8, 6}, {8, 0}, {7, 5}, {7, 2}, {7, 0}, {6, 2}, {6, 1}, {6, 0},
{5, 3}, {5, 2}, {5, 0}, {3, 7}, {3, 5}, {3, 2}, {3, 0}};
// Initialize to the least supported version, which acts as a safe fallback
auto target_compute_capability =
kSupportedVersions[std::size(kSupportedVersions) - 1];
Expand Down Expand Up @@ -284,6 +285,13 @@ std::string GetSmName(se::CudaComputeCapability compute_capability) {
se::CudaComputeCapability::FeatureExtension::kFamilyCompatibleFeatures;
}

return target_compute_capability;
}

std::string GetSmName(se::CudaComputeCapability compute_capability) {
se::CudaComputeCapability target_compute_capability =
ResolveSupportedComputeCapability(compute_capability);

// If the current CC isn't supported by LLVM and it is newer then
// the max supported LLVM version, do not warn about it. The end
// user can't do anything about this. E.g., PTX compiled for SM75 will
Expand Down
11 changes: 9 additions & 2 deletions xla/service/gpu/llvm_gpu_backend/nvptx_backend.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,13 +24,20 @@ limitations under the License.
#include "absl/status/statusor.h"
#include "llvm/IR/Module.h"
#include "llvm/Target/TargetMachine.h"
#include "xla/service/gpu/llvm_gpu_backend/ptx_version_util.h"
#include "xla/stream_executor/cuda/cuda_compute_capability.h"
#include "xla/stream_executor/device_description.h"
#include "xla/stream_executor/semantic_version.h"
#include "xla/xla.pb.h"

namespace xla::gpu::nvptx {

// Resolves the compute capability that XLA actually compiles for given the
// compute capability of the target device. If the device's compute capability
// is not directly supported by the bundled LLVM/ptxas, this returns the most
// advanced supported compute capability that the device can run, potentially
// with the family ("f") feature extension enabled.
stream_executor::CudaComputeCapability ResolveSupportedComputeCapability(
stream_executor::CudaComputeCapability compute_capability);

// Gets the GPU name as it's known to LLVM for a given compute
// capability. If we see an unrecognized compute capability, we
// return the highest one that is known and below the selected device.
Expand Down
21 changes: 21 additions & 0 deletions xla/service/gpu/llvm_gpu_backend/nvptx_backend_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,27 @@ TEST(UtilsTest, TestGetSmName) {
ASSERT_EQ(nvptx::GetSmName(se::CudaComputeCapability{13, 0}), "sm_121");
}

TEST(UtilsTest, UnknownCapabilityFallsBackToFamilyCompatible) {
using FeatureExtension = se::CudaComputeCapability::FeatureExtension;
// Directly supported compute capabilities keep their feature extension.
EXPECT_EQ(nvptx::ResolveSupportedComputeCapability(se::CudaComputeCapability{
10, 0, FeatureExtension::kAcceleratedFeatures}),
(se::CudaComputeCapability{
10, 0, FeatureExtension::kAcceleratedFeatures}));
// An unknown compute capability within a known major version falls back to
// the latest supported minor version with the family compatible extension.
// This mirrors a yet-unreleased device (e.g. sm_1099a) where ptxas only knows
// about sm_103f.
EXPECT_EQ(nvptx::ResolveSupportedComputeCapability(se::CudaComputeCapability{
10, 99, FeatureExtension::kAcceleratedFeatures}),
(se::CudaComputeCapability{
10, 3, FeatureExtension::kFamilyCompatibleFeatures}));
// When no family-compatible extension is available, don't use any.
EXPECT_EQ(nvptx::ResolveSupportedComputeCapability(se::CudaComputeCapability{
9, 99, FeatureExtension::kAcceleratedFeatures}),
(se::CudaComputeCapability{9, 0, FeatureExtension::kNone}));
}

using VersionPair = std::pair<se::SemanticVersion, se::SemanticVersion>;
using PtxVersionFromCudaVersionTest = ::testing::TestWithParam<VersionPair>;

Expand Down
10 changes: 5 additions & 5 deletions xla/service/gpu/nvptx_compiler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -558,8 +558,8 @@ NVPTXCompiler::CompileTargetBinary(
se::cuda::CompilationOptions compilation_options =
PtxCompileOptionsFromDebugOptions(module_config.debug_options());

se::CudaComputeCapability cc =
*device_description.gpu_compute_capability().cuda_compute_capability();
se::CudaComputeCapability cc = nvptx::ResolveSupportedComputeCapability(
*device_description.gpu_compute_capability().cuda_compute_capability());

// This may print multiple lines per HLO compilation because of the
// parallelized compilation of LLVM modules.
Expand Down Expand Up @@ -620,8 +620,8 @@ absl::StatusOr<std::vector<uint8_t>> NVPTXCompiler::LinkModules(
return std::vector<uint8_t>{};
}

auto cc =
device_description.gpu_compute_capability().cuda_compute_capability();
se::CudaComputeCapability cc = nvptx::ResolveSupportedComputeCapability(
*device_description.gpu_compute_capability().cuda_compute_capability());

ASSIGN_OR_RETURN(const se::cuda::CompilationProvider* compilation_provider,
GetCompilationProvider(debug_options, stream_exec));
Expand All @@ -640,7 +640,7 @@ absl::StatusOr<std::vector<uint8_t>> NVPTXCompiler::LinkModules(
<< compilation_provider->name();
ASSIGN_OR_RETURN(
se::cuda::Assembly assembly,
compilation_provider->CompileAndLink(*cc, inputs, compilation_options));
compilation_provider->CompileAndLink(cc, inputs, compilation_options));

return std::move(assembly.cubin);
}
Expand Down
Loading