Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 41 additions & 2 deletions src/AbstractOperations/kernel_function_operation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,17 @@ Construct a `KernelFunctionOperation` at location `(LX, LY, LZ)` on `grid` with
kernel_function(i, j, k, grid, arguments...)
```

If the location contains `Nothing`, `kernel_function` may also omit the indices of the
`Nothing` dimensions: for example, at `(Center, Center, Nothing)` it may be called with

```julia
kernel_function(i, j, grid, arguments...)
```

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I suggest including a full example that illustrates how this works including the KernelFunctionOperation{Center, Center, Nothing} constructor, to be fully explicit

The full three-index call is always preferred when it is applicable, so a function that
already accepts `(i, j, k, grid, arguments...)` keeps that behavior at every location; the
reduced call is used only when the full one is not applicable.

Note that `compute!(kfo::KernelFunctionOperation)` calls `compute!` on all `kfo.arguments`.

Examples
Expand Down Expand Up @@ -49,6 +60,20 @@ KernelFunctionOperation at (Face, Face, Center)
├── kernel_function: ζ₃ᶠᶠᶜ (generic function with 1 method)
└── arguments: ("Field", "Field")
```

Construct a `KernelFunctionOperation` at a reduced location using a kernel function
that omits the index of the `Nothing` dimension:

```jldoctest kfo
surface_kernel_function(i, j, grid) = i + j
surface_op = KernelFunctionOperation{Center, Center, Nothing}(surface_kernel_function, grid)

# output
KernelFunctionOperation at (Center, Center, ⋅)
├── grid: 1×8×8 RectilinearGrid{Float64, Periodic, Periodic, Bounded} on CPU with 1×3×3 halo
├── kernel_function: surface_kernel_function (generic function with 1 method)
└── arguments: ()
```
"""
struct KernelFunctionOperation{LX, LY, LZ, G, T, K, D} <: AbstractOperation{LX, LY, LZ, G, T}
kernel_function :: K
Expand All @@ -68,7 +93,21 @@ function KernelFunctionOperation{LX, LY, LZ}(kernel_function, grid, arguments...
return KernelFunctionOperation{LX, LY, LZ}(kernel_function, grid, tuple(arguments...))
end

@inline Base.getindex(κ::KernelFunctionOperation, i, j, k) = κ.kernel_function(i, j, k, κ.grid, κ.arguments...)
# `getindex` calls the kernel function with the full `(i, j, k, grid, args...)` signature
# whenever that call is applicable. At a reduced location it otherwise drops the indices of
# the `Nothing` dimensions, calling e.g. `kernel_function(i, j, grid, args...)`
@inline function Base.getindex(κ::KernelFunctionOperation{LX, LY, LZ}, i, j, k) where {LX, LY, LZ}
if applicable(κ.kernel_function, i, j, k, κ.grid, κ.arguments...)
return κ.kernel_function(i, j, k, κ.grid, κ.arguments...)
else
reduced_indices = (kept_index(LX, i)..., kept_index(LY, j)..., kept_index(LZ, k)...)
return κ.kernel_function(reduced_indices..., κ.grid, κ.arguments...)
end
end

@inline kept_index(::Type{Nothing}, index) = ()
@inline kept_index(::Type, index) = (index,)

indices(κ::KernelFunctionOperation) = construct_regionally(intersect_indices, location(κ), κ.arguments...)
compute_at!(κ::KernelFunctionOperation, time) = Tuple(compute_at!(d, time) for d in κ.arguments)

Expand Down Expand Up @@ -98,4 +137,4 @@ Base.show(io::IO, kfo::KernelFunctionOperation) =
Tuple(shortsummary(a) for a in kfo.arguments[1:end-1])...,
shortsummary(kfo.arguments[end])
end
)
)
42 changes: 42 additions & 0 deletions test/test_abstract_operations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,48 @@ for arch in archs
less_trivial_kernel_function(i, j, k, grid, u, v) = @inbounds u[i, j, k] * ℑxyᶠᶜᵃ(i, j, k, grid, v)
op = KernelFunctionOperation{Face, Center, Center}(less_trivial_kernel_function, grid, u, v)
@test op isa KernelFunctionOperation

two_index_kernel_function(i, j, grid) = i + j
op = KernelFunctionOperation{Center, Center, Nothing}(two_index_kernel_function, grid)
@test op isa KernelFunctionOperation
@test Array(interior(Field(op), :, :, 1)) == [i + j for i in 1:size(grid, 1), j in 1:size(grid, 2)]

one_index_kernel_function(k, grid) = 2k
op = KernelFunctionOperation{Nothing, Nothing, Center}(one_index_kernel_function, grid)
@test Array(interior(Field(op), 1, 1, :)) == [2k for k in 1:size(grid, 3)]

q = CenterField(grid)
set!(q, 2)
interior_pattern_kernel_function(i, k, grid, q) = @inbounds q[i, 1, k] * i * k
op = KernelFunctionOperation{Center, Nothing, Center}(interior_pattern_kernel_function, grid, q)
@test Array(interior(Field(op), :, 1, :)) == [2 * i * k for i in 1:size(grid, 1), k in 1:size(grid, 3)]

# Varargs kernel functions keep the full three-index convention...
varargs_kernel_function(arguments...) = arguments[1] + arguments[2] + arguments[3]
op = KernelFunctionOperation{Center, Center, Nothing}(varargs_kernel_function, grid)
@test Array(interior(Field(op), :, :, 1)) == [i + j + 1 for i in 1:size(grid, 1), j in 1:size(grid, 2)]

# ... unless only the reduced call is applicable (here the typed grid argument rejects `grid ← k`)
reduced_varargs_kernel_function(i, j, grid::Oceananigans.Grids.AbstractGrid, arguments...) = i * j + length(arguments)
op = KernelFunctionOperation{Center, Center, Nothing}(reduced_varargs_kernel_function, grid)
@test Array(interior(Field(op), :, :, 1)) == [i * j for i in 1:size(grid, 1), j in 1:size(grid, 2)]

# When the full three-index call is applicable it is preferred, even if a
# reduced-arity method also exists (heavily overloaded operators rely on this)
dual_kernel_function(i, j, grid) = i + j
dual_kernel_function(i, j, k, grid) = -7
op = KernelFunctionOperation{Center, Center, Nothing}(dual_kernel_function, grid)
@test Array(interior(Field(op), :, :, 1)) == fill(-7, size(grid, 1), size(grid, 2))

# Spacing operators (e.g. Δx) have reduced-arity helper methods that must not be
# mistaken for the reduced form at reduced locations
@test Array(interior(Field(xspacings(grid, Center(), Center(), Center())), :, 1, 1)) ==
[Oceananigans.Operators.Δx(i, 1, 1, grid, Center(), Center(), Center()) for i in 1:size(grid, 1)]

# Three-index kernel functions at reduced locations still work
three_index_kernel_function(i, j, k, grid) = i + j
op = KernelFunctionOperation{Center, Center, Nothing}(three_index_kernel_function, grid)
@test Array(interior(compute!(Field(op))))[:, :, 1] == [i + j for i in 1:size(grid, 1), j in 1:size(grid, 2)]

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You don't need to call compute!(Field(op)), because compute! is already called within Field(op).

The indexing is also messed up, so this can be interior(Field(op), :, :, 1) |> Array or something.

end

@testset "Fidelity of simple binary operations" begin
Expand Down