Skip to content

Commit f6f0b96

Browse files
committed
global_evaluate: faithful parallel evaluate() — fix out-of-domain mislocation
global_evaluate's parallel path had no correct handling for query points outside the (old) domain: the evaluation-swarm migrate routes unclaimed points by nearest rank centre-of-mass and strands them on an arbitrary rank, which then extrapolates from a geometrically-far cell -> silently-wrong values, parallel-only (e.g. an annulus boundary point reading the opposite side). This corrupted mesh-variable transfer on parallel mover-adapted meshes. Restore the serial evaluate() contract (interpolate inside / extrapolate from the TRUE nearest cell outside / flag inside-outside) with a best-claim out-of-domain fallback in global_evaluate_nd: allgather the (small, boundary-layer) extrapolated set; every rank reports its nearest-local-cell distance + its LOCAL rbf extrapolation; Allreduce(MIN dist / MIN rank / SUM winner value) picks the globally-nearest rank's value. Only unconditional collectives + local rbf_evaluate (never the collective FE interpolation, which would desync) -> deadlock-safe. O(boundary points), no dense global tree. Default on; GE_LOCAL_FALLBACK=0 restores legacy. Serial unchanged (gated mpi.size>1). Validated: deterministic-rotation gate, linear field T=x, np=5, max_err 1.06 -> 0.003 (== serial). Underworld development team with AI support from Claude Code
1 parent 29a2029 commit f6f0b96

1 file changed

Lines changed: 107 additions & 0 deletions

File tree

src/underworld3/function/_function.pyx

Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -368,6 +368,16 @@ def global_evaluate_nd( expr,
368368
Users should typically use :func:`underworld3.function.global_evaluate`
369369
which provides automatic unit handling and a cleaner interface.
370370
371+
Contract: this is a faithful *parallel* counterpart of :func:`evaluate` —
372+
a query point is interpolated wherever in the mesh it lands (on any rank),
373+
a point just outside the mesh is extrapolated from its true nearest cell,
374+
and ``check_extrapolated`` returns an inside/outside flag per point. The
375+
result is independent of the number of ranks (up to the rank-local
376+
extrapolation residual near partition seams). Points that no rank can
377+
locate in-cell are resolved by a best-claim reduction over ranks (see the
378+
out-of-domain block below); set ``GE_LOCAL_FALLBACK=0`` to restore the
379+
legacy behaviour where such points returned silently-wrong values.
380+
371381
Note it is not efficient to call this function to evaluate an expression at
372382
a single coordinate. Instead the user should provide a numpy array of all
373383
coordinates requiring evaluation.
@@ -520,6 +530,103 @@ def global_evaluate_nd( expr,
520530
return_value[index, :, :] = data_container.array[:, :, :]
521531
return_mask[index] = is_extrapolated.array[:]
522532

533+
# ------------------------------------------------------------------
534+
# Out-of-domain extrapolation — keep the parallel result a faithful
535+
# match for the serial ``evaluate()`` contract: interpolate a point
536+
# wherever it lands across ranks, extrapolate a point just outside the
537+
# mesh, and flag inside/outside.
538+
#
539+
# After the migrate round-trip, a query point that NO rank could locate
540+
# in one of its cells returns flagged-extrapolated but valued from
541+
# whichever rank the bare dm.migrate happened to strand it on — typically
542+
# a geometrically far, WRONG cell (the classic symptom is an annulus
543+
# boundary point reading a value from the opposite side of the domain).
544+
# Serial ``evaluate()`` instead extrapolates from the TRUE nearest cell.
545+
# Restore that contract with a "best-claim" reduction over the (small,
546+
# boundary-layer) stranded set:
547+
#
548+
# 1. allgather the extrapolated points so every rank holds the SAME
549+
# global set;
550+
# 2. each rank reports, per point, its nearest-local-cell distance and
551+
# its LOCAL rbf extrapolation of the field there;
552+
# 3. Allreduce(MIN distance) + Allreduce(MIN rank) tie-break picks the
553+
# rank whose nearest cell is globally closest, and Allreduce(SUM of
554+
# the winner-only value/flag) scatters that rank's extrapolation back.
555+
#
556+
# A point some rank actually contains (distance ~ 0) naturally wins, so
557+
# only genuinely-stranded points are corrected. Cost is O(boundary points)
558+
# — no dense global tree, no exhaustive search.
559+
#
560+
# DEADLOCK SAFETY — read before editing. Every collective here (allgather,
561+
# Allreduce) runs unconditionally on the IDENTICAL global set on every
562+
# rank, so all ranks stay in lockstep (n_ext_total is itself a reduced
563+
# value, so the `> 0` guard is taken identically everywhere). The per-rank
564+
# value MUST come from the LOCAL rbf path (rbf=True): the FE interpolation
565+
# path (petsc_interpolate / DMInterpolation) is itself collective and would
566+
# desync here, because each rank classifies the same global set against its
567+
# own domain (different interior-point counts) → hang. Never route the
568+
# fallback value through FE interpolation.
569+
#
570+
# Serial is left untouched (the serial path above already extrapolates from
571+
# the true nearest cell). Escape hatch: GE_LOCAL_FALLBACK=0 restores the
572+
# legacy (silently-wrong out-of-domain) behaviour; default on.
573+
# ------------------------------------------------------------------
574+
import os
575+
if uw.mpi.size > 1 and os.environ.get("GE_LOCAL_FALLBACK", "1") not in (
576+
"0", "off", "false", "no", ""):
577+
from mpi4py import MPI
578+
579+
comm = uw.mpi.comm
580+
ext_idx = np.where(return_mask[:, 0, 0])[0]
581+
ext_coords = np.ascontiguousarray(coords_array[ext_idx], dtype=np.double)
582+
583+
counts = np.array(comm.allgather(ext_coords.shape[0]), dtype=int)
584+
n_ext_total = int(counts.sum())
585+
586+
if n_ext_total > 0:
587+
parts = comm.allgather(ext_coords)
588+
all_ext = np.concatenate(
589+
[p for p in parts if p.size], axis=0).reshape(n_ext_total, -1)
590+
591+
# This rank's local rbf extrapolation of the global set. NON-collective
592+
# value path — see DEADLOCK SAFETY above (must be rbf=True, never FE).
593+
ext_vals, ext_flag = evaluate_nd(
594+
expr, all_ext, rbf=True, evalf=False, verbose=False,
595+
check_extrapolated=True,)
596+
ext_vals = np.ascontiguousarray(
597+
np.asarray(ext_vals, dtype=np.double).reshape((n_ext_total,) + expr_shape))
598+
ext_flag = np.asarray(ext_flag).reshape(n_ext_total).astype(np.int32)
599+
600+
# Nearest-local-cell distance for every point (local kd-tree query).
601+
mesh._build_kd_tree_index()
602+
dist2, _ = mesh._centroid_index.query(all_ext, k=1, sqr_dists=True)
603+
dist2 = np.ascontiguousarray(np.asarray(dist2, dtype=np.double).ravel())
604+
605+
# Globally-nearest cell per point, lowest rank as the tie-break.
606+
min_dist2 = np.empty(n_ext_total, dtype=np.double)
607+
comm.Allreduce([dist2, MPI.DOUBLE], [min_dist2, MPI.DOUBLE], op=MPI.MIN)
608+
my_claim = np.where(dist2 <= min_dist2 * (1.0 + 1e-12) + 1e-300,
609+
comm.rank, comm.size).astype(np.int32)
610+
win_rank = np.empty(n_ext_total, dtype=np.int32)
611+
comm.Allreduce([my_claim, MPI.INT], [win_rank, MPI.INT], op=MPI.MIN)
612+
i_win = (win_rank == comm.rank)
613+
614+
# Winner contributes value+flag, everyone else zero; SUM selects it.
615+
contrib_val = np.ascontiguousarray(
616+
np.where(i_win[:, None, None], ext_vals, 0.0))
617+
best_val = np.empty_like(contrib_val)
618+
comm.Allreduce([contrib_val, MPI.DOUBLE], [best_val, MPI.DOUBLE], op=MPI.SUM)
619+
contrib_flag = np.where(i_win, ext_flag, 0).astype(np.int32)
620+
best_flag = np.empty(n_ext_total, dtype=np.int32)
621+
comm.Allreduce([contrib_flag, MPI.INT], [best_flag, MPI.INT], op=MPI.SUM)
622+
623+
# Scatter this rank's segment of the global set back to its points.
624+
offset = int(counts[:comm.rank].sum())
625+
seg = slice(offset, offset + ext_coords.shape[0])
626+
if ext_idx.size:
627+
return_value[ext_idx, :, :] = best_val[seg]
628+
return_mask[ext_idx, 0, 0] = best_flag[seg].astype(bool)
629+
523630
if not check_extrapolated:
524631
return return_value
525632
else:

0 commit comments

Comments
 (0)