diff --git a/src/AbstractOperations/kernel_function_operation.jl b/src/AbstractOperations/kernel_function_operation.jl index f5c261bb5c2..0a9aade71de 100644 --- a/src/AbstractOperations/kernel_function_operation.jl +++ b/src/AbstractOperations/kernel_function_operation.jl @@ -11,6 +11,17 @@ Construct a `KernelFunctionOperation` at location `(LX, LY, LZ)` on `grid` with kernel_function(i, j, k, grid, arguments...) ``` +If the location contains `Nothing`, `kernel_function` may also omit the indices of the +`Nothing` dimensions: for example, at `(Center, Center, Nothing)` it may be called with + +```julia +kernel_function(i, j, grid, arguments...) +``` + +The full three-index call is always preferred when it is applicable, so a function that +already accepts `(i, j, k, grid, arguments...)` keeps that behavior at every location; the +reduced call is used only when the full one is not applicable. + Note that `compute!(kfo::KernelFunctionOperation)` calls `compute!` on all `kfo.arguments`. Examples @@ -49,6 +60,20 @@ KernelFunctionOperation at (Face, Face, Center) ├── kernel_function: ζ₃ᶠᶠᶜ (generic function with 1 method) └── arguments: ("Field", "Field") ``` + +Construct a `KernelFunctionOperation` at a reduced location using a kernel function +that omits the index of the `Nothing` dimension: + +```jldoctest kfo +surface_kernel_function(i, j, grid) = i + j +surface_op = KernelFunctionOperation{Center, Center, Nothing}(surface_kernel_function, grid) + +# output +KernelFunctionOperation at (Center, Center, ⋅) +├── grid: 1×8×8 RectilinearGrid{Float64, Periodic, Periodic, Bounded} on CPU with 1×3×3 halo +├── kernel_function: surface_kernel_function (generic function with 1 method) +└── arguments: () +``` """ struct KernelFunctionOperation{LX, LY, LZ, G, T, K, D} <: AbstractOperation{LX, LY, LZ, G, T} kernel_function :: K @@ -68,7 +93,21 @@ function KernelFunctionOperation{LX, LY, LZ}(kernel_function, grid, arguments... return KernelFunctionOperation{LX, LY, LZ}(kernel_function, grid, tuple(arguments...)) end -@inline Base.getindex(κ::KernelFunctionOperation, i, j, k) = κ.kernel_function(i, j, k, κ.grid, κ.arguments...) +# `getindex` calls the kernel function with the full `(i, j, k, grid, args...)` signature +# whenever that call is applicable. At a reduced location it otherwise drops the indices of +# the `Nothing` dimensions, calling e.g. `kernel_function(i, j, grid, args...)` +@inline function Base.getindex(κ::KernelFunctionOperation{LX, LY, LZ}, i, j, k) where {LX, LY, LZ} + if applicable(κ.kernel_function, i, j, k, κ.grid, κ.arguments...) + return κ.kernel_function(i, j, k, κ.grid, κ.arguments...) + else + reduced_indices = (kept_index(LX, i)..., kept_index(LY, j)..., kept_index(LZ, k)...) + return κ.kernel_function(reduced_indices..., κ.grid, κ.arguments...) + end +end + +@inline kept_index(::Type{Nothing}, index) = () +@inline kept_index(::Type, index) = (index,) + indices(κ::KernelFunctionOperation) = construct_regionally(intersect_indices, location(κ), κ.arguments...) compute_at!(κ::KernelFunctionOperation, time) = Tuple(compute_at!(d, time) for d in κ.arguments) @@ -98,4 +137,4 @@ Base.show(io::IO, kfo::KernelFunctionOperation) = Tuple(shortsummary(a) for a in kfo.arguments[1:end-1])..., shortsummary(kfo.arguments[end]) end -) + ) diff --git a/test/test_abstract_operations.jl b/test/test_abstract_operations.jl index 39736d1ad30..04b9ee60aaa 100644 --- a/test/test_abstract_operations.jl +++ b/test/test_abstract_operations.jl @@ -182,6 +182,48 @@ for arch in archs less_trivial_kernel_function(i, j, k, grid, u, v) = @inbounds u[i, j, k] * ℑxyᶠᶜᵃ(i, j, k, grid, v) op = KernelFunctionOperation{Face, Center, Center}(less_trivial_kernel_function, grid, u, v) @test op isa KernelFunctionOperation + + two_index_kernel_function(i, j, grid) = i + j + op = KernelFunctionOperation{Center, Center, Nothing}(two_index_kernel_function, grid) + @test op isa KernelFunctionOperation + @test Array(interior(Field(op), :, :, 1)) == [i + j for i in 1:size(grid, 1), j in 1:size(grid, 2)] + + one_index_kernel_function(k, grid) = 2k + op = KernelFunctionOperation{Nothing, Nothing, Center}(one_index_kernel_function, grid) + @test Array(interior(Field(op), 1, 1, :)) == [2k for k in 1:size(grid, 3)] + + q = CenterField(grid) + set!(q, 2) + interior_pattern_kernel_function(i, k, grid, q) = @inbounds q[i, 1, k] * i * k + op = KernelFunctionOperation{Center, Nothing, Center}(interior_pattern_kernel_function, grid, q) + @test Array(interior(Field(op), :, 1, :)) == [2 * i * k for i in 1:size(grid, 1), k in 1:size(grid, 3)] + + # Varargs kernel functions keep the full three-index convention... + varargs_kernel_function(arguments...) = arguments[1] + arguments[2] + arguments[3] + op = KernelFunctionOperation{Center, Center, Nothing}(varargs_kernel_function, grid) + @test Array(interior(Field(op), :, :, 1)) == [i + j + 1 for i in 1:size(grid, 1), j in 1:size(grid, 2)] + + # ... unless only the reduced call is applicable (here the typed grid argument rejects `grid ← k`) + reduced_varargs_kernel_function(i, j, grid::Oceananigans.Grids.AbstractGrid, arguments...) = i * j + length(arguments) + op = KernelFunctionOperation{Center, Center, Nothing}(reduced_varargs_kernel_function, grid) + @test Array(interior(Field(op), :, :, 1)) == [i * j for i in 1:size(grid, 1), j in 1:size(grid, 2)] + + # When the full three-index call is applicable it is preferred, even if a + # reduced-arity method also exists (heavily overloaded operators rely on this) + dual_kernel_function(i, j, grid) = i + j + dual_kernel_function(i, j, k, grid) = -7 + op = KernelFunctionOperation{Center, Center, Nothing}(dual_kernel_function, grid) + @test Array(interior(Field(op), :, :, 1)) == fill(-7, size(grid, 1), size(grid, 2)) + + # Spacing operators (e.g. Δx) have reduced-arity helper methods that must not be + # mistaken for the reduced form at reduced locations + @test Array(interior(Field(xspacings(grid, Center(), Center(), Center())), :, 1, 1)) == + [Oceananigans.Operators.Δx(i, 1, 1, grid, Center(), Center(), Center()) for i in 1:size(grid, 1)] + + # Three-index kernel functions at reduced locations still work + three_index_kernel_function(i, j, k, grid) = i + j + op = KernelFunctionOperation{Center, Center, Nothing}(three_index_kernel_function, grid) + @test Array(interior(compute!(Field(op))))[:, :, 1] == [i + j for i in 1:size(grid, 1), j in 1:size(grid, 2)] end @testset "Fidelity of simple binary operations" begin