diff --git a/docs/source/builder/writing-kernels.md b/docs/source/builder/writing-kernels.md index e030cb83..d9fb46d3 100644 --- a/docs/source/builder/writing-kernels.md +++ b/docs/source/builder/writing-kernels.md @@ -222,11 +222,20 @@ The following sections enumerate all supported options for `build.toml`. This option is provided for kernels that require functionality only provided by newer CUDA toolkits. +### Framework sections + +The framework section specifies framework-specific settings. The name of +the section depends on the framework that is used. The currently-supported +frameworks are: + +- AOT-compiled Torch kernel (`torch`). +- AOT-compiled TVM-FFI kernel (`tvm-ffi`). +- JIT-compiled or not-compiled Torch kernel (`torch-noarch`, experimental). + ### `torch` -This section describes the Torch extension. In the future, there may be -similar sections for other frameworks. This section has the following -options: +This framework section is used for AOT-compiled Torch kernels, and has the +following options: - `src` (required): a list of source files and headers. - `pyext` (optional): the list of extensions for Python files. Default: @@ -244,6 +253,35 @@ options: For an example, see the [`relu-torch-stable-abi`](https://github.com/huggingface/kernels/tree/main/examples/kernels/relu-torch-stable-abi) example kernel. +### `tvm-ffi` + +This framework section is used for AOT-compiled TVM-FFI kernels. + +- `src` (required): a list of source files and headers. +- `pyext` (optional): the list of extensions for Python files. Default: + `["py", "pyi"]`. +- `include` (optional): include directories relative to the project root. + Default: `[]`. + +### `torch-noarch` + +The `torch-noarch` section is used for JIT-compiled kernels or kernels that +do not require any compilation (e.g. a kernel that packages plain PyTorch +layers). + +Normally, it is expected that this type of kernel runs on all CUDA capabilities +or ROCm architectures. 
However, for kernels that support only a limited range +of archs, the `cuda-capabilities` and `rocm-archs` options can be used to specify +the supported archs. These are then exported to `metadata.json` for consumption +by e.g. the Hugging Face Hub. + +- `pyext` (optional): the list of extensions for Python files. Default: + `["py", "pyi"]`. +- `cuda-capabilities` (optional): a list of CUDA compute capabilities the + kernel supports (e.g. `["9.0", "10.0"]`). +- `rocm-archs` (optional): a list of ROCm architectures the kernel supports + (e.g. `["gfx942"]`). + ### `kernel.` Specification of a kernel with the name ``. Multiple `kernel.` diff --git a/kernel-builder/src/pyproject/common.rs b/kernel-builder/src/pyproject/common.rs index c03ddb2e..5443f0b5 100644 --- a/kernel-builder/src/pyproject/common.rs +++ b/kernel-builder/src/pyproject/common.rs @@ -3,8 +3,8 @@ use std::path::PathBuf; use eyre::Result; use itertools::Itertools; -use kernels_data::config::{Backend, General}; -use kernels_data::metadata::{BackendInfo, Metadata}; +use kernels_data::config::{Backend, Build}; +use kernels_data::metadata::Metadata; use crate::pyproject::ops_identifier::KernelIdentifier; use crate::pyproject::FileSet; @@ -21,37 +21,15 @@ pub fn write_compat_py(file_set: &mut FileSet) -> Result<()> { } pub fn write_metadata( - general: &General, + build: &Build, kernel_id: &KernelIdentifier, file_set: &mut FileSet, ) -> Result<()> { for backend in &Backend::all() { let writer = file_set.entry(format!("metadata-{backend}.json")); - let python_depends = general - .python_depends() - .map(|deps| Ok(deps?.0.to_owned())) - .chain( - general - .backend_python_depends(*backend) - .map(|deps| Ok(deps?.0.to_owned())), - ) - .collect::>>()?; - - let metadata = Metadata { - id: kernel_id.to_string_for_backend(*backend), - name: general.name.clone(), - version: general.version, - license: general.license.clone(), - upstream: general.upstream.clone(), - source: general.source.clone(), - python_depends, 
backend: BackendInfo { - archs: None, - backend_type: *backend, - }, - digest: None, - }; + let metadata = + Metadata::for_backend(build, kernel_id.to_string_for_backend(*backend), *backend)?; serde_json::to_writer_pretty(writer, &metadata)?; } diff --git a/kernel-builder/src/pyproject/torch/mod.rs b/kernel-builder/src/pyproject/torch/mod.rs index dae4265a..23e05a1a 100644 --- a/kernel-builder/src/pyproject/torch/mod.rs +++ b/kernel-builder/src/pyproject/torch/mod.rs @@ -69,10 +69,11 @@ fn write_pyproject_toml( let writer = file_set.entry("pyproject.toml"); // Common python dependencies (no backend-specific ones) - let python_dependencies = itertools::process_results(general.python_depends(), |iter| { - iter.flat_map(|(_, deps)| deps.python.iter().map(|d| format!("\"{}\"", d.pkg))) - .join(", ") - })?; + let python_dependencies = + itertools::process_results(general.general_python_depends(), |iter| { + iter.flat_map(|(_, deps)| deps.python.iter().map(|d| format!("\"{}\"", d.pkg))) + .join(", ") + })?; // Collect backend-specific dependencies for all backends let mut backend_dependencies = Vec::new(); @@ -271,7 +272,7 @@ pub fn write_torch_ext( write_torch_registration_macros(&mut file_set)?; - write_metadata(&build.general, kernel_id, &mut file_set)?; + write_metadata(build, kernel_id, &mut file_set)?; Ok(file_set) } diff --git a/kernel-builder/src/pyproject/torch/noarch.rs b/kernel-builder/src/pyproject/torch/noarch.rs index db3d9661..1fbe8d52 100644 --- a/kernel-builder/src/pyproject/torch/noarch.rs +++ b/kernel-builder/src/pyproject/torch/noarch.rs @@ -34,7 +34,7 @@ pub fn write_torch_ext_noarch( &mut file_set, )?; write_setup_py(&mut file_set)?; - write_metadata(&build.general, kernel_id, &mut file_set)?; + write_metadata(build, kernel_id, &mut file_set)?; Ok(file_set) } @@ -81,10 +81,11 @@ fn write_pyproject_toml( }); // Common python dependencies (no backend-specific ones) - let python_dependencies = itertools::process_results(general.python_depends(), |iter| 
{ - iter.flat_map(|(_, deps)| deps.python.iter().map(|d| format!("\"{}\"", d.pkg))) - .join(", ") - })?; + let python_dependencies = + itertools::process_results(general.general_python_depends(), |iter| { + iter.flat_map(|(_, deps)| deps.python.iter().map(|d| format!("\"{}\"", d.pkg))) + .join(", ") + })?; // Collect backend-specific dependencies for all backends let mut backend_dependencies = Vec::new(); diff --git a/kernel-builder/src/pyproject/tvm_ffi/mod.rs b/kernel-builder/src/pyproject/tvm_ffi/mod.rs index 4818303c..ccb4e52a 100644 --- a/kernel-builder/src/pyproject/tvm_ffi/mod.rs +++ b/kernel-builder/src/pyproject/tvm_ffi/mod.rs @@ -65,7 +65,7 @@ pub fn write_tvm_ffi_ext( write_pyproject_toml(env, &build.general, &mut file_set)?; - write_metadata(&build.general, kernel_id, &mut file_set)?; + write_metadata(build, kernel_id, &mut file_set)?; Ok(file_set) } @@ -107,10 +107,11 @@ pub fn write_pyproject_toml( let writer = file_set.entry("pyproject.toml"); // Common python dependencies (no backend-specific ones) - let python_dependencies = itertools::process_results(general.python_depends(), |iter| { - iter.flat_map(|(_, deps)| deps.python.iter().map(|d| format!("\"{}\"", d.pkg))) - .join(", ") - })?; + let python_dependencies = + itertools::process_results(general.general_python_depends(), |iter| { + iter.flat_map(|(_, deps)| deps.python.iter().map(|d| format!("\"{}\"", d.pkg))) + .join(", ") + })?; // Collect backend-specific dependencies for all backends let mut backend_dependencies = Vec::new(); diff --git a/kernels-data/src/config/mod.rs b/kernels-data/src/config/mod.rs index 5165ad6d..c4b346d1 100644 --- a/kernels-data/src/config/mod.rs +++ b/kernels-data/src/config/mod.rs @@ -58,6 +58,17 @@ impl Framework { _ => None, } } + + pub(crate) fn precomputable_backend_archs(&self, backend: Backend) -> Option> { + match self { + Framework::TorchNoarch(torch_noarch) => match backend { + Backend::Cuda => torch_noarch.cuda_capabilities.clone(), + Backend::Rocm => 
torch_noarch.rocm_archs.clone(), + _ => None, + }, + _ => None, + } + } } impl Build { @@ -99,7 +110,7 @@ pub struct General { } impl General { - pub fn python_depends( + pub fn general_python_depends( &self, ) -> Box> + '_> { let general_python_deps = match self.python_depends.as_ref() { @@ -147,6 +158,16 @@ impl General { } })) } + + pub fn all_python_depends(&self, backend: Backend) -> Result> { + self.general_python_depends() + .map(|deps| Ok(deps?.0.to_owned())) + .chain( + self.backend_python_depends(backend) + .map(|deps| Ok(deps?.0.to_owned())), + ) + .collect::>>() + } } pub struct CudaGeneral { @@ -204,6 +225,11 @@ impl Torch { pub struct TorchNoarch { pub pyext: Option>, + /// CUDA capabilities to write into metadata. + pub cuda_capabilities: Option>, + + /// ROCM archs to write into metadata. + pub rocm_archs: Option>, } impl TorchNoarch { diff --git a/kernels-data/src/config/v3.rs b/kernels-data/src/config/v3.rs index c4873bd9..1436d77f 100644 --- a/kernels-data/src/config/v3.rs +++ b/kernels-data/src/config/v3.rs @@ -171,7 +171,11 @@ impl TryFrom for super::Build { let framework = match build.framework { Some(Framework::Torch(torch)) => super::Framework::Torch(torch.into()), Some(Framework::TvmFfi(tvm_ffi)) => super::Framework::TvmFfi(tvm_ffi.into()), - None => super::Framework::TorchNoarch(super::TorchNoarch { pyext: None }), + None => super::Framework::TorchNoarch(super::TorchNoarch { + pyext: None, + cuda_capabilities: None, + rocm_archs: None, + }), }; Ok(Self { diff --git a/kernels-data/src/config/v4.rs b/kernels-data/src/config/v4.rs index 44279fed..b0d554a2 100644 --- a/kernels-data/src/config/v4.rs +++ b/kernels-data/src/config/v4.rs @@ -97,6 +97,12 @@ pub struct Torch { #[serde(deny_unknown_fields)] pub struct TorchNoarch { pub pyext: Option>, + + #[serde(default)] + pub cuda_capabilities: Option>, + + #[serde(default)] + pub rocm_archs: Option>, } #[derive(Debug, Deserialize, Clone, Serialize)] @@ -263,6 +269,8 @@ impl From for 
super::TorchNoarch { fn from(torch_noarch: TorchNoarch) -> Self { Self { pyext: torch_noarch.pyext, + cuda_capabilities: torch_noarch.cuda_capabilities, + rocm_archs: torch_noarch.rocm_archs, } } } @@ -460,6 +468,8 @@ impl From for TorchNoarch { fn from(torch_noarch: super::TorchNoarch) -> Self { Self { pyext: torch_noarch.pyext, + cuda_capabilities: torch_noarch.cuda_capabilities, + rocm_archs: torch_noarch.rocm_archs, } } } diff --git a/kernels-data/src/metadata.rs b/kernels-data/src/metadata.rs index 109c93fd..45f551d9 100644 --- a/kernels-data/src/metadata.rs +++ b/kernels-data/src/metadata.rs @@ -3,7 +3,7 @@ use std::str::FromStr; use eyre::Result; use serde::{Deserialize, Serialize}; -use crate::config::{Backend, GitUrl, KernelName}; +use crate::config::{Backend, Build, GitUrl, KernelName}; use crate::digest::Digest; #[derive(Debug, Deserialize, Serialize)] @@ -34,6 +34,33 @@ pub struct Metadata { } impl Metadata { + /// Construct metadata for a specific backend. + /// + /// This constructor creates metadata for a specific backend from the + /// kernel build configuration and kernel identifier. + /// + /// Backend archs are only populated for Torch noarch, since + /// the archs need to be computed at build time for arch-specific frameworks. + pub fn for_backend(build: &Build, id: String, backend: Backend) -> Result { + let python_depends = build.general.all_python_depends(backend)?; + let archs = build.framework.precomputable_backend_archs(backend); + + Ok(Self { + id, + name: build.general.name.clone(), + version: build.general.version, + license: build.general.license.clone(), + upstream: build.general.upstream.clone(), + source: build.general.source.clone(), + python_depends, + backend: BackendInfo { + archs, + backend_type: backend, + }, + digest: None, + }) + } + /// Read the metadata from a JSON byte slice. pub fn from_bytes(bytes: &[u8]) -> Result { Ok(serde_json::from_slice(bytes)?) 
@@ -52,3 +79,95 @@ impl FromStr for Metadata { Ok(serde_json::from_str(s)?) } } + +#[cfg(test)] +mod tests { + use std::collections::HashMap; + + use crate::config::{Backend, Build, Framework, General, KernelName, TorchNoarch, TvmFfi}; + + use super::Metadata; + + fn torch_noarch_build() -> Build { + Build { + general: General { + name: KernelName::new("test-kernel").unwrap(), + version: 1, + license: "apache-2.0".to_string(), + upstream: None, + source: None, + backends: vec![Backend::Cuda, Backend::Rocm, Backend::Cpu], + hub: None, + python_depends: None, + cuda: None, + neuron: None, + xpu: None, + }, + kernels: HashMap::new(), + framework: Framework::TorchNoarch(TorchNoarch { + pyext: None, + cuda_capabilities: Some(vec!["7.0".to_string(), "8.0".to_string()]), + rocm_archs: Some(vec!["gfx90a".to_string()]), + }), + } + } + + #[test] + fn cuda_archs_for_torch_noarch() { + let build = torch_noarch_build(); + let metadata = Metadata::for_backend(&build, "test-id".to_string(), Backend::Cuda).unwrap(); + + assert_eq!(metadata.backend.backend_type, Backend::Cuda); + assert_eq!( + metadata.backend.archs, + Some(vec!["7.0".to_string(), "8.0".to_string()]) + ); + } + + #[test] + fn rocm_archs_for_torch_noarch() { + let build = torch_noarch_build(); + let metadata = Metadata::for_backend(&build, "test-id".to_string(), Backend::Rocm).unwrap(); + + assert_eq!(metadata.backend.backend_type, Backend::Rocm); + assert_eq!(metadata.backend.archs, Some(vec!["gfx90a".to_string()])); + } + + #[test] + fn no_archs_for_cpu_with_torch_noarch() { + let build = torch_noarch_build(); + let metadata = Metadata::for_backend(&build, "test-id".to_string(), Backend::Cpu).unwrap(); + + assert_eq!(metadata.backend.backend_type, Backend::Cpu); + assert!(metadata.backend.archs.is_none()); + } + + #[test] + fn no_archs_for_arch_framework() { + let build = Build { + general: General { + name: KernelName::new("test-kernel").unwrap(), + version: 1, + license: "apache-2.0".to_string(), + upstream: 
None, + source: None, + backends: vec![Backend::Cuda], + hub: None, + python_depends: None, + cuda: None, + neuron: None, + xpu: None, + }, + kernels: HashMap::new(), + framework: Framework::TvmFfi(TvmFfi { + include: None, + pyext: None, + src: vec![], + }), + }; + let metadata = Metadata::for_backend(&build, "test-id".to_string(), Backend::Cuda).unwrap(); + + assert_eq!(metadata.backend.backend_type, Backend::Cuda); + assert!(metadata.backend.archs.is_none()); + } +}