@@ -368,6 +368,16 @@ def global_evaluate_nd( expr,
368368 Users should typically use :func:`underworld3.function.global_evaluate`
369369 which provides automatic unit handling and a cleaner interface.
370370
371+ Contract: this is a faithful *parallel* counterpart of :func:`evaluate` —
372+ a query point is interpolated wherever in the mesh it lands (on any rank),
373+ a point just outside the mesh is extrapolated from its true nearest cell,
374+ and ``check_extrapolated`` returns an inside/outside flag per point. The
375+ result is independent of the number of ranks (up to the rank-local
376+ extrapolation residual near partition seams). Points that no rank can
377+ locate in-cell are resolved by a best-claim reduction over ranks (see the
378+ out-of-domain block below); set ``GE_LOCAL_FALLBACK=0`` to restore the
379+ legacy behaviour where such points returned silently-wrong values.
380+
371381 Note it is not efficient to call this function to evaluate an expression at
372382 a single coordinate. Instead the user should provide a numpy array of all
373383 coordinates requiring evaluation.
@@ -520,6 +530,103 @@ def global_evaluate_nd( expr,
520530 return_value[index, :, :] = data_container.array[:, :, :]
521531 return_mask[index] = is_extrapolated.array[:]
522532
533+ # ------------------------------------------------------------------
534+ # Out-of-domain extrapolation — keep the parallel result a faithful
535+ # match for the serial ``evaluate()`` contract: interpolate a point
536+ # wherever it lands across ranks, extrapolate a point just outside the
537+ # mesh, and flag inside/outside.
538+ #
539+ # After the migrate round-trip, a query point that NO rank could locate
540+ # in one of its cells returns flagged-extrapolated but valued from
541+ # whichever rank the bare dm.migrate happened to strand it on — typically
542+ # a geometrically far, WRONG cell (the classic symptom is an annulus
543+ # boundary point reading a value from the opposite side of the domain).
544+ # Serial ``evaluate()`` instead extrapolates from the TRUE nearest cell.
545+ # Restore that contract with a "best-claim" reduction over the (small,
546+ # boundary-layer) stranded set:
547+ #
548+ # 1. allgather the extrapolated points so every rank holds the SAME
549+ # global set;
550+ # 2. each rank reports, per point, its nearest-local-cell distance and
551+ # its LOCAL rbf extrapolation of the field there;
552+ # 3. Allreduce(MIN distance) + Allreduce(MIN rank) tie-break picks the
553+ # rank whose nearest cell is globally closest, and Allreduce(SUM of
554+ # the winner-only value/flag) scatters that rank's extrapolation back.
555+ #
556+ # A point some rank actually contains (distance ~ 0) naturally wins, so
557+ # only genuinely-stranded points are corrected. Cost is O(boundary points)
558+ # — no dense global tree, no exhaustive search.
559+ #
560+ # DEADLOCK SAFETY — read before editing. Every collective here (allgather,
561+ # Allreduce) runs unconditionally on the IDENTICAL global set on every
562+ # rank, so all ranks stay in lockstep (n_ext_total is itself a reduced
563+ # value, so the `> 0` guard is taken identically everywhere). The per-rank
564+ # value MUST come from the LOCAL rbf path (rbf=True): the FE interpolation
565+ # path (petsc_interpolate / DMInterpolation) is itself collective and would
566+ # desync here, because each rank classifies the same global set against its
567+ # own domain (different interior-point counts) → hang. Never route the
568+ # fallback value through FE interpolation.
569+ #
570+ # Serial is left untouched (the serial path above already extrapolates from
571+ # the true nearest cell). Escape hatch: GE_LOCAL_FALLBACK=0 restores the
572+ # legacy (silently-wrong out-of-domain) behaviour; default on.
573+ # ------------------------------------------------------------------
574+ import os
575+ if uw.mpi.size > 1 and os.environ.get(" GE_LOCAL_FALLBACK" , " 1" ) not in (
576+ " 0" , " off" , " false" , " no" , " " ):
577+ from mpi4py import MPI
578+
579+ comm = uw.mpi.comm
580+ ext_idx = np.where(return_mask[:, 0 , 0 ])[0 ]
581+ ext_coords = np.ascontiguousarray(coords_array[ext_idx], dtype = np.double)
582+
583+ counts = np.array(comm.allgather(ext_coords.shape[0 ]), dtype = int )
584+ n_ext_total = int (counts.sum())
585+
586+ if n_ext_total > 0 :
587+ parts = comm.allgather(ext_coords)
588+ all_ext = np.concatenate(
589+ [p for p in parts if p.size], axis = 0 ).reshape(n_ext_total, - 1 )
590+
591+ # This rank's local rbf extrapolation of the global set. NON-collective
592+ # value path — see DEADLOCK SAFETY above (must be rbf=True, never FE).
593+ ext_vals, ext_flag = evaluate_nd(
594+ expr, all_ext, rbf = True , evalf = False , verbose = False ,
595+ check_extrapolated = True ,)
596+ ext_vals = np.ascontiguousarray(
597+ np.asarray(ext_vals, dtype = np.double).reshape((n_ext_total,) + expr_shape))
598+ ext_flag = np.asarray(ext_flag).reshape(n_ext_total).astype(np.int32)
599+
600+ # Nearest-local-cell distance for every point (local kd-tree query).
601+ mesh._build_kd_tree_index()
602+ dist2, _ = mesh._centroid_index.query(all_ext, k = 1 , sqr_dists = True )
603+ dist2 = np.ascontiguousarray(np.asarray(dist2, dtype = np.double).ravel())
604+
605+ # Globally-nearest cell per point, lowest rank as the tie-break.
606+ min_dist2 = np.empty(n_ext_total, dtype = np.double)
607+ comm.Allreduce([dist2, MPI.DOUBLE], [min_dist2, MPI.DOUBLE], op = MPI.MIN)
608+ my_claim = np.where(dist2 <= min_dist2 * (1.0 + 1e-12 ) + 1e-300 ,
609+ comm.rank, comm.size).astype(np.int32)
610+ win_rank = np.empty(n_ext_total, dtype = np.int32)
611+ comm.Allreduce([my_claim, MPI.INT], [win_rank, MPI.INT], op = MPI.MIN)
612+ i_win = (win_rank == comm.rank)
613+
614+ # Winner contributes value+flag, everyone else zero; SUM selects it.
615+ contrib_val = np.ascontiguousarray(
616+ np.where(i_win[:, None , None ], ext_vals, 0.0 ))
617+ best_val = np.empty_like(contrib_val)
618+ comm.Allreduce([contrib_val, MPI.DOUBLE], [best_val, MPI.DOUBLE], op = MPI.SUM)
619+ contrib_flag = np.where(i_win, ext_flag, 0 ).astype(np.int32)
620+ best_flag = np.empty(n_ext_total, dtype = np.int32)
621+ comm.Allreduce([contrib_flag, MPI.INT], [best_flag, MPI.INT], op = MPI.SUM)
622+
623+ # Scatter this rank's segment of the global set back to its points.
624+ offset = int (counts[:comm.rank].sum())
625+ seg = slice (offset, offset + ext_coords.shape[0 ])
626+ if ext_idx.size:
627+ return_value[ext_idx, :, :] = best_val[seg]
628+ return_mask[ext_idx, 0 , 0 ] = best_flag[seg].astype(bool )
629+
523630 if not check_extrapolated:
524631 return return_value
525632 else :
0 commit comments