Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 41 additions & 3 deletions docs/source/builder/writing-kernels.md
Original file line number Diff line number Diff line change
Expand Up @@ -222,11 +222,20 @@ The following sections enumerate all supported options for `build.toml`.
This option is provided for kernels that require functionality only
provided by newer CUDA toolkits.

### Framework sections

The framework section specifies framework-specific settings. The name of
the section depends on the framework that is used. The currently-supported
frameworks are:

- AOT-compiled Torch kernel (`torch`).
- AOT-compiled TVM-FFI kernel (`tvm-ffi`).
- JIT-compiled or not-compiled Torch kernel (`torch-noarch`, experimental).

### `torch`

This section describes the Torch extension. In the future, there may be
similar sections for other frameworks. This section has the following
options:
This framework section is used for AOT-compiled Torch kernels, and has the
following options:

- `src` (required): a list of source files and headers.
- `pyext` (optional): the list of extensions for Python files. Default:
Expand All @@ -244,6 +253,35 @@ options:
For an example, see the [`relu-torch-stable-abi`](https://github.com/huggingface/kernels/tree/main/examples/kernels/relu-torch-stable-abi)
example kernel.

### `tvm-ffi`

This framework section is used for AOT-compiled TVM-FFI kernels, and has the
following options:

- `src` (required): a list of source files and headers.
- `pyext` (optional): the list of extensions for Python files. Default:
`["py", "pyi"]`.
- `include` (optional): include directories relative to the project root.
Default: `[]`.

### `torch-noarch`

The `torch-noarch` section is used for JIT-compiled kernels or kernels that
do not require any compilation (e.g. a kernel that packages plain PyTorch
layers).

Normally, it is expected that this type of kernel runs on all CUDA capabilities
or ROCm architectures. However, for kernels that support only a limited range
of archs, the `cuda-capabilities` and `rocm-archs` options can be used to specify
the supported archs. These are then exported to `metadata.json` for consumption
by e.g. the Hugging Face Hub.

- `pyext` (optional): the list of extensions for Python files. Default:
`["py", "pyi"]`.
- `cuda-capabilities` (optional): a list of CUDA compute capabilities the
kernel supports (e.g. `["9.0", "10.0"]`).
- `rocm-archs` (optional): a list of ROCm architectures the kernel supports
(e.g. `["gfx942"]`).

### `kernel.<name>`

Specification of a kernel with the name `<name>`. Multiple `kernel.<name>`
Expand Down
32 changes: 5 additions & 27 deletions kernel-builder/src/pyproject/common.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@ use std::path::PathBuf;
use eyre::Result;
use itertools::Itertools;

use kernels_data::config::{Backend, General};
use kernels_data::metadata::{BackendInfo, Metadata};
use kernels_data::config::{Backend, Build};
use kernels_data::metadata::Metadata;

use crate::pyproject::ops_identifier::KernelIdentifier;
use crate::pyproject::FileSet;
Expand All @@ -21,37 +21,15 @@ pub fn write_compat_py(file_set: &mut FileSet) -> Result<()> {
}

pub fn write_metadata(
general: &General,
build: &Build,
kernel_id: &KernelIdentifier,
file_set: &mut FileSet,
) -> Result<()> {
for backend in &Backend::all() {
let writer = file_set.entry(format!("metadata-{backend}.json"));

let python_depends = general
.python_depends()
.map(|deps| Ok(deps?.0.to_owned()))
.chain(
general
.backend_python_depends(*backend)
.map(|deps| Ok(deps?.0.to_owned())),
)
.collect::<Result<Vec<_>>>()?;

let metadata = Metadata {
id: kernel_id.to_string_for_backend(*backend),
name: general.name.clone(),
version: general.version,
license: general.license.clone(),
upstream: general.upstream.clone(),
source: general.source.clone(),
python_depends,
backend: BackendInfo {
archs: None,
backend_type: *backend,
},
digest: None,
};
let metadata =
Metadata::for_backend(build, kernel_id.to_string_for_backend(*backend), *backend)?;

serde_json::to_writer_pretty(writer, &metadata)?;
}
Expand Down
11 changes: 6 additions & 5 deletions kernel-builder/src/pyproject/torch/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -69,10 +69,11 @@ fn write_pyproject_toml(
let writer = file_set.entry("pyproject.toml");

// Common python dependencies (no backend-specific ones)
let python_dependencies = itertools::process_results(general.python_depends(), |iter| {
iter.flat_map(|(_, deps)| deps.python.iter().map(|d| format!("\"{}\"", d.pkg)))
.join(", ")
})?;
let python_dependencies =
itertools::process_results(general.general_python_depends(), |iter| {
iter.flat_map(|(_, deps)| deps.python.iter().map(|d| format!("\"{}\"", d.pkg)))
.join(", ")
})?;

// Collect backend-specific dependencies for all backends
let mut backend_dependencies = Vec::new();
Expand Down Expand Up @@ -271,7 +272,7 @@ pub fn write_torch_ext(

write_torch_registration_macros(&mut file_set)?;

write_metadata(&build.general, kernel_id, &mut file_set)?;
write_metadata(build, kernel_id, &mut file_set)?;

Ok(file_set)
}
11 changes: 6 additions & 5 deletions kernel-builder/src/pyproject/torch/noarch.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ pub fn write_torch_ext_noarch(
&mut file_set,
)?;
write_setup_py(&mut file_set)?;
write_metadata(&build.general, kernel_id, &mut file_set)?;
write_metadata(build, kernel_id, &mut file_set)?;

Ok(file_set)
}
Expand Down Expand Up @@ -81,10 +81,11 @@ fn write_pyproject_toml(
});

// Common python dependencies (no backend-specific ones)
let python_dependencies = itertools::process_results(general.python_depends(), |iter| {
iter.flat_map(|(_, deps)| deps.python.iter().map(|d| format!("\"{}\"", d.pkg)))
.join(", ")
})?;
let python_dependencies =
itertools::process_results(general.general_python_depends(), |iter| {
iter.flat_map(|(_, deps)| deps.python.iter().map(|d| format!("\"{}\"", d.pkg)))
.join(", ")
})?;

// Collect backend-specific dependencies for all backends
let mut backend_dependencies = Vec::new();
Expand Down
11 changes: 6 additions & 5 deletions kernel-builder/src/pyproject/tvm_ffi/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ pub fn write_tvm_ffi_ext(

write_pyproject_toml(env, &build.general, &mut file_set)?;

write_metadata(&build.general, kernel_id, &mut file_set)?;
write_metadata(build, kernel_id, &mut file_set)?;

Ok(file_set)
}
Expand Down Expand Up @@ -107,10 +107,11 @@ pub fn write_pyproject_toml(
let writer = file_set.entry("pyproject.toml");

// Common python dependencies (no backend-specific ones)
let python_dependencies = itertools::process_results(general.python_depends(), |iter| {
iter.flat_map(|(_, deps)| deps.python.iter().map(|d| format!("\"{}\"", d.pkg)))
.join(", ")
})?;
let python_dependencies =
itertools::process_results(general.general_python_depends(), |iter| {
iter.flat_map(|(_, deps)| deps.python.iter().map(|d| format!("\"{}\"", d.pkg)))
.join(", ")
})?;

// Collect backend-specific dependencies for all backends
let mut backend_dependencies = Vec::new();
Expand Down
28 changes: 27 additions & 1 deletion kernels-data/src/config/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,17 @@ impl Framework {
_ => None,
}
}

/// Returns the architectures declared in the build configuration for `backend`.
///
/// Only the `TorchNoarch` framework carries such lists: its `cuda_capabilities`
/// are returned for `Backend::Cuda` and its `rocm_archs` for `Backend::Rocm`.
/// Every other framework/backend combination yields `None`.
pub(crate) fn precomputable_backend_archs(&self, backend: Backend) -> Option<Vec<String>> {
    match (self, backend) {
        (Framework::TorchNoarch(noarch), Backend::Cuda) => noarch.cuda_capabilities.clone(),
        (Framework::TorchNoarch(noarch), Backend::Rocm) => noarch.rocm_archs.clone(),
        _ => None,
    }
}
}

impl Build {
Expand Down Expand Up @@ -99,7 +110,7 @@ pub struct General {
}

impl General {
pub fn python_depends(
pub fn general_python_depends(
&self,
) -> Box<dyn Iterator<Item = Result<(&str, &PythonDependency)>> + '_> {
let general_python_deps = match self.python_depends.as_ref() {
Expand Down Expand Up @@ -147,6 +158,16 @@ impl General {
}
}))
}

/// Collects the names of all Python dependencies that apply to `backend`.
///
/// The names from `general_python_depends` come first, followed by the names
/// from `backend_python_depends(backend)`. The first dependency that fails to
/// resolve aborts the collection and its error is returned.
pub fn all_python_depends(&self, backend: Backend) -> Result<Vec<String>> {
    // Create both iterators before consuming either one.
    let general = self.general_python_depends();
    let backend_specific = self.backend_python_depends(backend);

    let mut names = Vec::new();
    for dep in general {
        names.push(dep?.0.to_owned());
    }
    for dep in backend_specific {
        names.push(dep?.0.to_owned());
    }

    Ok(names)
}
}

pub struct CudaGeneral {
Expand Down Expand Up @@ -204,6 +225,11 @@ impl Torch {

/// Settings of the `TorchNoarch` framework variant.
pub struct TorchNoarch {
/// Extensions of Python files, if configured.
pub pyext: Option<Vec<String>>,
/// CUDA capabilities to write into metadata.
pub cuda_capabilities: Option<Vec<String>>,

/// ROCm archs to write into metadata.
pub rocm_archs: Option<Vec<String>>,
}

impl TorchNoarch {
Expand Down
6 changes: 5 additions & 1 deletion kernels-data/src/config/v3.rs
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,11 @@ impl TryFrom<Build> for super::Build {
let framework = match build.framework {
Some(Framework::Torch(torch)) => super::Framework::Torch(torch.into()),
Some(Framework::TvmFfi(tvm_ffi)) => super::Framework::TvmFfi(tvm_ffi.into()),
None => super::Framework::TorchNoarch(super::TorchNoarch { pyext: None }),
None => super::Framework::TorchNoarch(super::TorchNoarch {
pyext: None,
cuda_capabilities: None,
rocm_archs: None,
}),
};

Ok(Self {
Expand Down
10 changes: 10 additions & 0 deletions kernels-data/src/config/v4.rs
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,12 @@ pub struct Torch {
#[serde(deny_unknown_fields)]
/// Serde representation of the `TorchNoarch` framework settings.
///
/// Unknown fields are rejected (`deny_unknown_fields`).
pub struct TorchNoarch {
/// Extensions of Python files, if configured.
pub pyext: Option<Vec<String>>,

/// CUDA capabilities supported by the kernel; `None` when not specified.
#[serde(default)]
pub cuda_capabilities: Option<Vec<String>>,

/// ROCm archs supported by the kernel; `None` when not specified.
#[serde(default)]
pub rocm_archs: Option<Vec<String>>,
}

#[derive(Debug, Deserialize, Clone, Serialize)]
Expand Down Expand Up @@ -263,6 +269,8 @@ impl From<TorchNoarch> for super::TorchNoarch {
fn from(torch_noarch: TorchNoarch) -> Self {
Self {
pyext: torch_noarch.pyext,
cuda_capabilities: torch_noarch.cuda_capabilities,
rocm_archs: torch_noarch.rocm_archs,
}
}
}
Expand Down Expand Up @@ -460,6 +468,8 @@ impl From<super::TorchNoarch> for TorchNoarch {
/// Converts the common `TorchNoarch` back into the serde-level representation,
/// moving every field across unchanged.
fn from(torch_noarch: super::TorchNoarch) -> Self {
    let super::TorchNoarch {
        pyext,
        cuda_capabilities,
        rocm_archs,
    } = torch_noarch;

    Self {
        pyext,
        cuda_capabilities,
        rocm_archs,
    }
}
}
Expand Down
Loading
Loading