From 711d41f9431b22e61ab8378e8a737ca010a661a8 Mon Sep 17 00:00:00 2001 From: Simone Silvestri Date: Sat, 6 Jun 2026 20:19:54 +0200 Subject: [PATCH 1/3] fix functions --- .../kernel_function_operation.jl | 53 ++++++++++++++++++- test/test_abstract_operations.jl | 20 +++++++ 2 files changed, 71 insertions(+), 2 deletions(-) diff --git a/src/AbstractOperations/kernel_function_operation.jl b/src/AbstractOperations/kernel_function_operation.jl index f5c261bb5c2..cddfbfe24c0 100644 --- a/src/AbstractOperations/kernel_function_operation.jl +++ b/src/AbstractOperations/kernel_function_operation.jl @@ -1,4 +1,4 @@ -using Oceananigans.Utils: shortsummary, construct_regionally, prettysummary +using Oceananigans.Utils: Utils, shortsummary, construct_regionally, prettysummary """ KernelFunctionOperation{LX, LY, LZ}(kernel_function, grid, arguments...) @@ -11,6 +11,13 @@ Construct a `KernelFunctionOperation` at location `(LX, LY, LZ)` on `grid` with kernel_function(i, j, k, grid, arguments...) ``` +If the location contains `Nothing`, `kernel_function` may also omit the indices of the +`Nothing` dimensions: for example, at `(Center, Center, Nothing)` it may be called with + +```julia +kernel_function(i, j, grid, arguments...) +``` + Note that `compute!(kfo::KernelFunctionOperation)` calls `compute!` on all `kfo.arguments`. Examples @@ -49,6 +56,20 @@ KernelFunctionOperation at (Face, Face, Center) ├── kernel_function: ζ₃ᶠᶠᶜ (generic function with 1 method) └── arguments: ("Field", "Field") ``` + +Construct a `KernelFunctionOperation` at a reduced location using a kernel function +that omits the index of the `Nothing` dimension: + +```jldoctest kfo +surface_kernel_function(i, j, grid) = i + j +surface_op = KernelFunctionOperation{Center, Center, Nothing}(surface_kernel_function, grid) + +# output +KernelFunctionOperation at (Center, Center, ⋅) +├── grid: 1×8×8 RectilinearGrid{Float64, Periodic, Periodic, Bounded} on CPU with 1×3×3 halo +├── kernel_function: surface_kernel_function (generic function with 1 method) +└── arguments: () +``` """ struct KernelFunctionOperation{LX, LY, LZ, G, T, K, D} <: AbstractOperation{LX, LY, LZ, G, T} kernel_function :: K @@ -65,9 +86,37 @@ end # Convenience outer constructor: splat arguments into a tuple. # T defaults to eltype(grid) via the inner constructor. function KernelFunctionOperation{LX, LY, LZ}(kernel_function, grid, arguments...) where {LX, LY, LZ} + kernel_function = possibly_reduced_kernel_function(kernel_function, (LX, LY, LZ), arguments) return KernelFunctionOperation{LX, LY, LZ}(kernel_function, grid, tuple(arguments...)) end +""" + ReducedKernelFunction{D, F} + +Wrap a kernel function defined at a reduced location, forwarding only the indices of the non-`Nothing` dimensions `D`, +so that `kernel_function(i, j, k, grid, args...)` calls, e.g. for `D = (1, 2)`, `kernel_function(i, j, grid, args...)`. +""" +struct ReducedKernelFunction{Dims, F} + kernel_function :: F + ReducedKernelFunction{Dims}(kernel_function::F) where {Dims, F} = new{Dims, F}(kernel_function) +end + +@inline (rkf::ReducedKernelFunction{Dims})(i, j, k, grid, arguments...) where Dims = rkf.kernel_function(map(d -> (i, j, k)[d], Dims)..., grid, arguments...) + +function possibly_reduced_kernel_function(kernel_function, location, arguments) + kept_dimensions = Tuple(d for d in 1:3 if location[d] !== Nothing) + if length(kept_dimensions) == 3 + return kernel_function + end + reduced_parameter_count = length(kept_dimensions) + 1 + length(arguments) + has_reduced_method = any(m -> !m.isva && m.nargs - 1 == reduced_parameter_count, methods(kernel_function)) + return has_reduced_method ? ReducedKernelFunction{kept_dimensions}(kernel_function) : kernel_function +end + +Adapt.adapt_structure(to, rkf::ReducedKernelFunction{Dims}) where Dims = ReducedKernelFunction{Dims}(Adapt.adapt(to, rkf.kernel_function)) + +Utils.prettysummary(rkf::ReducedKernelFunction) = prettysummary(rkf.kernel_function) + @inline Base.getindex(κ::KernelFunctionOperation, i, j, k) = κ.kernel_function(i, j, k, κ.grid, κ.arguments...) indices(κ::KernelFunctionOperation) = construct_regionally(intersect_indices, location(κ), κ.arguments...) compute_at!(κ::KernelFunctionOperation, time) = Tuple(compute_at!(d, time) for d in κ.arguments) @@ -98,4 +147,4 @@ Base.show(io::IO, kfo::KernelFunctionOperation) = Tuple(shortsummary(a) for a in kfo.arguments[1:end-1])..., shortsummary(kfo.arguments[end]) end -) + ) diff --git a/test/test_abstract_operations.jl b/test/test_abstract_operations.jl index 39736d1ad30..c993a2dcfa0 100644 --- a/test/test_abstract_operations.jl +++ b/test/test_abstract_operations.jl @@ -182,6 +182,26 @@ for arch in archs less_trivial_kernel_function(i, j, k, grid, u, v) = @inbounds u[i, j, k] * ℑxyᶠᶜᵃ(i, j, k, grid, v) op = KernelFunctionOperation{Face, Center, Center}(less_trivial_kernel_function, grid, u, v) @test op isa KernelFunctionOperation + + two_index_kernel_function(i, j, grid) = i + j + op = KernelFunctionOperation{Center, Center, Nothing}(two_index_kernel_function, grid) + @test op isa KernelFunctionOperation + @test Array(interior(compute!(Field(op))))[:, :, 1] == [i + j for i in 1:size(grid, 1), j in 1:size(grid, 2)] + + one_index_kernel_function(k, grid) = 2k + op = KernelFunctionOperation{Nothing, Nothing, Center}(one_index_kernel_function, grid) + @test Array(interior(compute!(Field(op))))[1, 1, :] == [2k for k in 1:size(grid, 3)] + + q = CenterField(grid) + set!(q, 2) + interior_pattern_kernel_function(i, k, grid, q) = @inbounds q[i, 1, k] * i * k + op = KernelFunctionOperation{Center, Nothing, Center}(interior_pattern_kernel_function, grid, q) + @test Array(interior(compute!(Field(op))))[:, 1, :] == [2 * i * k for i in 1:size(grid, 1), k in 1:size(grid, 3)] + + # Three-index kernel functions at reduced locations still work + three_index_kernel_function(i, j, k, grid) = i + j + op = KernelFunctionOperation{Center, Center, Nothing}(three_index_kernel_function, grid) + @test Array(interior(compute!(Field(op))))[:, :, 1] == [i + j for i in 1:size(grid, 1), j in 1:size(grid, 2)] end @testset "Fidelity of simple binary operations" begin From a91ff3d55f0b40050982af7f1ad33899633c0d39 Mon Sep 17 00:00:00 2001 From: Simone Silvestri Date: Thu, 11 Jun 2026 08:21:42 +0200 Subject: [PATCH 2/3] simpler methodology --- .../kernel_function_operation.jl | 42 +++++++------------ test/test_abstract_operations.jl | 28 +++++++++++-- 2 files changed, 41 insertions(+), 29 deletions(-) diff --git a/src/AbstractOperations/kernel_function_operation.jl b/src/AbstractOperations/kernel_function_operation.jl index cddfbfe24c0..d7a325e0404 100644 --- a/src/AbstractOperations/kernel_function_operation.jl +++ b/src/AbstractOperations/kernel_function_operation.jl @@ -1,4 +1,4 @@ -using Oceananigans.Utils: Utils, shortsummary, construct_regionally, prettysummary +using Oceananigans.Utils: shortsummary, construct_regionally, prettysummary """ KernelFunctionOperation{LX, LY, LZ}(kernel_function, grid, arguments...) @@ -18,6 +18,10 @@ If the location contains `Nothing`, `kernel_function` may also omit the indices kernel_function(i, j, grid, arguments...) ``` +The full three-index call is always preferred when it is applicable, so a function that +already accepts `(i, j, k, grid, arguments...)` keeps that behavior at every location; the +reduced call is used only when the full one is not applicable. + Note that `compute!(kfo::KernelFunctionOperation)` calls `compute!` on all `kfo.arguments`. Examples @@ -86,38 +90,24 @@ end # Convenience outer constructor: splat arguments into a tuple. # T defaults to eltype(grid) via the inner constructor. function KernelFunctionOperation{LX, LY, LZ}(kernel_function, grid, arguments...) where {LX, LY, LZ} - kernel_function = possibly_reduced_kernel_function(kernel_function, (LX, LY, LZ), arguments) return KernelFunctionOperation{LX, LY, LZ}(kernel_function, grid, tuple(arguments...)) end -""" - ReducedKernelFunction{D, F} - -Wrap a kernel function defined at a reduced location, forwarding only the indices of the non-`Nothing` dimensions `D`, -so that `kernel_function(i, j, k, grid, args...)` calls, e.g. for `D = (1, 2)`, `kernel_function(i, j, grid, args...)`. -""" -struct ReducedKernelFunction{Dims, F} - kernel_function :: F - ReducedKernelFunction{Dims}(kernel_function::F) where {Dims, F} = new{Dims, F}(kernel_function) -end - -@inline (rkf::ReducedKernelFunction{Dims})(i, j, k, grid, arguments...) where Dims = rkf.kernel_function(map(d -> (i, j, k)[d], Dims)..., grid, arguments...) - -function possibly_reduced_kernel_function(kernel_function, location, arguments) - kept_dimensions = Tuple(d for d in 1:3 if location[d] !== Nothing) - if length(kept_dimensions) == 3 - return kernel_function +# `getindex` calls the kernel function with the full `(i, j, k, grid, args...)` signature +# whenever that call is applicable. At a reduced location it otherwise drops the indices of +# the `Nothing` dimensions, calling e.g. `kernel_function(i, j, grid, args...)` +@inline function Base.getindex(κ::KernelFunctionOperation{LX, LY, LZ}, i, j, k) where {LX, LY, LZ} + if applicable(κ.kernel_function, i, j, k, κ.grid, κ.arguments...) + return κ.kernel_function(i, j, k, κ.grid, κ.arguments...) + else + reduced_indices = (kept_index(LX, i)..., kept_index(LY, j)..., kept_index(LZ, k)...) + return κ.kernel_function(reduced_indices..., κ.grid, κ.arguments...) end - reduced_parameter_count = length(kept_dimensions) + 1 + length(arguments) - has_reduced_method = any(m -> !m.isva && m.nargs - 1 == reduced_parameter_count, methods(kernel_function)) - return has_reduced_method ? ReducedKernelFunction{kept_dimensions}(kernel_function) : kernel_function end -Adapt.adapt_structure(to, rkf::ReducedKernelFunction{Dims}) where Dims = ReducedKernelFunction{Dims}(Adapt.adapt(to, rkf.kernel_function)) - -Utils.prettysummary(rkf::ReducedKernelFunction) = prettysummary(rkf.kernel_function) +@inline kept_index(::Type{Nothing}, index) = () +@inline kept_index(::Type, index) = (index,) -@inline Base.getindex(κ::KernelFunctionOperation, i, j, k) = κ.kernel_function(i, j, k, κ.grid, κ.arguments...) indices(κ::KernelFunctionOperation) = construct_regionally(intersect_indices, location(κ), κ.arguments...) compute_at!(κ::KernelFunctionOperation, time) = Tuple(compute_at!(d, time) for d in κ.arguments) diff --git a/test/test_abstract_operations.jl b/test/test_abstract_operations.jl index c993a2dcfa0..04b9ee60aaa 100644 --- a/test/test_abstract_operations.jl +++ b/test/test_abstract_operations.jl @@ -186,17 +186,39 @@ for arch in archs two_index_kernel_function(i, j, grid) = i + j op = KernelFunctionOperation{Center, Center, Nothing}(two_index_kernel_function, grid) @test op isa KernelFunctionOperation - @test Array(interior(compute!(Field(op))))[:, :, 1] == [i + j for i in 1:size(grid, 1), j in 1:size(grid, 2)] + @test Array(interior(Field(op), :, :, 1)) == [i + j for i in 1:size(grid, 1), j in 1:size(grid, 2)] one_index_kernel_function(k, grid) = 2k op = KernelFunctionOperation{Nothing, Nothing, Center}(one_index_kernel_function, grid) - @test Array(interior(compute!(Field(op))))[1, 1, :] == [2k for k in 1:size(grid, 3)] + @test Array(interior(Field(op), 1, 1, :)) == [2k for k in 1:size(grid, 3)] q = CenterField(grid) set!(q, 2) interior_pattern_kernel_function(i, k, grid, q) = @inbounds q[i, 1, k] * i * k op = KernelFunctionOperation{Center, Nothing, Center}(interior_pattern_kernel_function, grid, q) - @test Array(interior(compute!(Field(op))))[:, 1, :] == [2 * i * k for i in 1:size(grid, 1), k in 1:size(grid, 3)] + @test Array(interior(Field(op), :, 1, :)) == [2 * i * k for i in 1:size(grid, 1), k in 1:size(grid, 3)] + + # Varargs kernel functions keep the full three-index convention... + varargs_kernel_function(arguments...) = arguments[1] + arguments[2] + arguments[3] + op = KernelFunctionOperation{Center, Center, Nothing}(varargs_kernel_function, grid) + @test Array(interior(Field(op), :, :, 1)) == [i + j + 1 for i in 1:size(grid, 1), j in 1:size(grid, 2)] + + # ... unless only the reduced call is applicable (here the typed grid argument rejects `grid ← k`) + reduced_varargs_kernel_function(i, j, grid::Oceananigans.Grids.AbstractGrid, arguments...) = i * j + length(arguments) + op = KernelFunctionOperation{Center, Center, Nothing}(reduced_varargs_kernel_function, grid) + @test Array(interior(Field(op), :, :, 1)) == [i * j for i in 1:size(grid, 1), j in 1:size(grid, 2)] + + # When the full three-index call is applicable it is preferred, even if a + # reduced-arity method also exists (heavily overloaded operators rely on this) + dual_kernel_function(i, j, grid) = i + j + dual_kernel_function(i, j, k, grid) = -7 + op = KernelFunctionOperation{Center, Center, Nothing}(dual_kernel_function, grid) + @test Array(interior(Field(op), :, :, 1)) == fill(-7, size(grid, 1), size(grid, 2)) + + # Spacing operators (e.g. Δx) have reduced-arity helper methods that must not be + # mistaken for the reduced form at reduced locations + @test Array(interior(Field(xspacings(grid, Center(), Center(), Center())), :, 1, 1)) == + [Oceananigans.Operators.Δx(i, 1, 1, grid, Center(), Center(), Center()) for i in 1:size(grid, 1)] # Three-index kernel functions at reduced locations still work three_index_kernel_function(i, j, k, grid) = i + j From c6f75ef3a1996c9632360d95c92771949c283fb7 Mon Sep 17 00:00:00 2001 From: Simone Silvestri Date: Thu, 11 Jun 2026 10:32:39 +0200 Subject: [PATCH 3/3] remove whitespace --- src/AbstractOperations/kernel_function_operation.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/AbstractOperations/kernel_function_operation.jl b/src/AbstractOperations/kernel_function_operation.jl index d7a325e0404..0a9aade71de 100644 --- a/src/AbstractOperations/kernel_function_operation.jl +++ b/src/AbstractOperations/kernel_function_operation.jl @@ -95,7 +95,7 @@ end # `getindex` calls the kernel function with the full `(i, j, k, grid, args...)` signature # whenever that call is applicable. At a reduced location it otherwise drops the indices of -# the `Nothing` dimensions, calling e.g. `kernel_function(i, j, grid, args...)` +# the `Nothing` dimensions, calling e.g. `kernel_function(i, j, grid, args...)` @inline function Base.getindex(κ::KernelFunctionOperation{LX, LY, LZ}, i, j, k) where {LX, LY, LZ} if applicable(κ.kernel_function, i, j, k, κ.grid, κ.arguments...) return κ.kernel_function(i, j, k, κ.grid, κ.arguments...)