diff --git a/lor_bug_report.py b/lor_bug_report.py new file mode 100644 index 000000000..2f2b4dd48 --- /dev/null +++ b/lor_bug_report.py @@ -0,0 +1,240 @@ +############################################################################### +# mpi-sppy: MPI-based Stochastic Programming in PYthon +# +# Copyright (c) 2024, Lawrence Livermore National Security, LLC, Alliance for +# Sustainable Energy, LLC, The Regents of the University of California, et al. +# All rights reserved. Please see the files COPYRIGHT.md and LICENSE.md for +# full copyright and license information. +############################################################################### +"""Summarize LOR_bug diagnostic output from PR #717. + +PR #717 instruments mpisppy/spbase.py::SPBase.allreduce_or to print a 4-line +block on every call (from cyl_rk == 0 of the cylinder's mpicomm). This script +parses such a log and writes a short report on the four hypotheses being +tested: + + H1. self.mpicomm has wider membership than the cylinder it should. + H2. Buffer memory underneath local_val was nonzero / non-boolean. + H3. The Allreduce reducer path is malfunctioning. + H4. Duplicate rank participation in self.mpicomm. + +Usage: + python lor_bug_report.py <logfile> +""" + +import re +import sys +from collections import defaultdict + + +HEADER_RE = re.compile( + r"^\[LOR_bug call=(?P<call>\d+) cls=(?P<cls>\S+) " + r"world_rk=(?P<world_rk>\d+) host=(?P<host>\S+) pid=(?P<pid>\d+)\] " + r"mpicomm size=(?P<size>\d+) name=(?P<name>.+)$" +) +WORLD_RANKS_RE = re.compile( + r"^\s*world_ranks: min=(?P<wr_min>\d+) max=(?P<wr_max>\d+) " + r"count=(?P<count>\d+) unique=(?P<unique>\d+)$" +) +REDUCTIONS_RE = re.compile( + r"^\s*reductions: sum=(?P<sum>-?\d+) max=(?P<max>-?\d+) " + r"lor=(?P<lor>-?\d+) rank_sum=(?P<rank_sum>-?\d+) " + r"expected_rank_sum=(?P<expected_rank_sum>-?\d+)$" +) +GATHER_RE = re.compile( + r"^\s*gather: gather_sum=(?P<gather_sum>-?\d+) " + r"nonzero_reports=(?P<nonzero_reports>\d+)$" +) + + +def parse(path): + """Return a list of dicts, one per [LOR_bug ...] 
block.""" + with open(path) as f: + lines = f.readlines() + + entries = [] + i = 0 + n = len(lines) + while i < n: + m = HEADER_RE.match(lines[i].rstrip()) + if not m: + i += 1 + continue + entry = { + "call": int(m["call"]), + "cls": m["cls"], + "world_rk": int(m["world_rk"]), + "host": m["host"], + "pid": int(m["pid"]), + "size": int(m["size"]), + "name": m["name"], + } + i += 1 + # The next three lines should be world_ranks / reductions / gather, + # in that order. Tolerate missing lines defensively. + for pat in (WORLD_RANKS_RE, REDUCTIONS_RE, GATHER_RE): + if i >= n: + break + mm = pat.match(lines[i].rstrip()) + if not mm: + break + for k, v in mm.groupdict().items(): + entry[k] = int(v) + i += 1 + entries.append(entry) + return entries + + +def _examples(rows, fields, n=5): + """Format up to n example rows, showing call ID + the listed fields.""" + out = [] + for e in rows[:n]: + extras = " ".join(f"{f}={e.get(f, '?')}" for f in fields) + out.append( + f" cls={e['cls']} call={e['call']} " + f"world_rk={e['world_rk']} host={e['host']} {extras}" + ) + if len(rows) > n: + out.append(f" (... {len(rows) - n} more truncated ...)") + return "\n".join(out) + + +# MPI implementations often leave new communicators with an empty or +# generic default name. When that happens, grouping by (cls, name) can +# collapse distinct physical comms into one bucket and falsely trip H1. +_DEFAULT_COMM_NAMES = {"''", '""', "'MPI_COMM_WORLD'", "'MPI_COMMUNICATOR'", + "", "''"} + + +def report(entries, path): + print(f"LOR_bug report for: {path}") + print(f"Parsed {len(entries)} [LOR_bug ...] blocks.") + if not entries: + print("\nNo diagnostic blocks found. 
Was the run on the LOR_bug branch?") + return + + # ---------- per-comm summary ---------- + by_comm = defaultdict(list) + for e in entries: + by_comm[(e["cls"], e["name"])].append(e) + + print("\nPer-comm summary (one printer per comm; cyl_rk == 0 only):") + for (cls, name), es in sorted(by_comm.items()): + sizes = sorted({e["size"] for e in es}) + wrs = sorted({e["world_rk"] for e in es}) + print(f" cls={cls} name={name}") + print(f" calls={len(es)} sizes={sizes} printer_world_rk={wrs}") + + # ---------- H1: wider membership ---------- + # Signal: size varies within a single (cls, name) bucket, OR printer + # world_rk varies across calls for the same logical comm (meaning + # different ranks took the "rank 0" role — only possible if comm + # membership shifted). + print("\nH1 — wider mpicomm membership than expected:") + h1_hits = [] + for (cls, name), es in by_comm.items(): + sizes = {e["size"] for e in es} + printers = {e["world_rk"] for e in es} + if len(sizes) > 1 or len(printers) > 1: + h1_hits.append((cls, name, sorted(sizes), sorted(printers))) + if h1_hits: + print(" WARNING: comm membership is not stable across calls:") + for cls, name, sizes, printers in h1_hits: + print(f" cls={cls} name={name} sizes={sizes} " + f"printer_world_rks={printers}") + else: + print(" OK: every comm has a stable size and stable rank-0 printer.") + defaulted = sorted({n for (_, n) in by_comm if n in _DEFAULT_COMM_NAMES}) + if defaulted: + print(f" NOTE: some comms have default/empty names ({defaulted}); " + "distinct physical comms may collapse into one bucket here " + "and produce spurious H1 hits. Check `printer_world_rk` " + "in the per-comm summary above.") + + # Also: if two different comms share the same printer world rank, that + # rank straddles two cylinders -- possible cross-cylinder contamination. 
+ printer_to_comms = defaultdict(set) + for (cls, name), es in by_comm.items(): + for e in es: + printer_to_comms[e["world_rk"]].add((cls, name)) + shared = {wr: cs for wr, cs in printer_to_comms.items() if len(cs) > 1} + if shared: + print(" NOTE: world ranks acting as printer for multiple comms:") + for wr, cs in sorted(shared.items()): + print(f" world_rk={wr} comms={sorted(cs)}") + + # ---------- H2: buffer aliasing / non-boolean input ---------- + # Signature per PR description: nonzero local_val where it should be 0. + # The unambiguous tell is max > 1 (input was not a Python bool). + print("\nH2 — buffer aliasing / non-boolean input:") + nonbool = [e for e in entries if e.get("max", 0) > 1] + nonzero = [e for e in entries if e.get("gather_sum", 0) > 0] + print(f" Calls with any nonzero local_val: {len(nonzero)} / {len(entries)}" + f" (these may be legitimate True returns)") + if nonbool: + print(f" STRONG SIGNAL: {len(nonbool)} calls had max > 1 " + f"(input was not boolean)") + print(_examples(nonbool, ["max", "gather_sum"])) + else: + print(" OK: every nonzero local_val was 1 (boolean).") + + # ---------- H3: reducer malfunction ---------- + # (a) Allreduce SUM disagrees with the Allgather-summed local_vals. + # (b) rank_sum != expected sum-of-ranks for a comm of this size. 
+ print("\nH3 — Allreduce reducer malfunction:") + sum_mismatch = [e for e in entries + if "sum" in e and "gather_sum" in e + and e["sum"] != e["gather_sum"]] + rank_sum_fail = [e for e in entries + if "rank_sum" in e and "expected_rank_sum" in e + and e["rank_sum"] != e["expected_rank_sum"]] + print(f" sum != gather_sum (reducer disagreeing with gather): " + f"{len(sum_mismatch)}") + if sum_mismatch: + print(_examples(sum_mismatch, ["sum", "gather_sum"])) + print(f" rank_sum sanity failures (SUM broken on this comm): " + f"{len(rank_sum_fail)}") + if rank_sum_fail: + print(_examples(rank_sum_fail, ["rank_sum", "expected_rank_sum"])) + + # ---------- H4: duplicate rank participation ---------- + print("\nH4 — duplicate rank participation in mpicomm:") + dups = [e for e in entries + if "unique" in e and "count" in e and e["unique"] < e["count"]] + print(f" Calls with duplicate world ranks: {len(dups)}") + if dups: + print(_examples(dups, ["count", "unique"])) + + # ---------- Verdict ---------- + print("\nVerdict:") + triggered = [] + if h1_hits: + triggered.append("H1 (wider/unstable membership)") + if nonbool: + triggered.append("H2 (non-boolean input)") + if sum_mismatch or rank_sum_fail: + triggered.append("H3 (reducer)") + if dups: + triggered.append("H4 (duplicate ranks)") + if triggered: + print(" Hypotheses triggered: " + ", ".join(triggered)) + else: + if nonzero: + print(" No invariant violations. 
Some calls returned nonzero;" + " consistent with legitimate shutdown signals.") + else: + print(" Clean log: no anomalies on any of the four hypotheses.") + + +def main(argv): + if len(argv) != 2: + print(f"Usage: {argv[0]} <logfile>", file=sys.stderr) + return 2 + path = argv[1] + entries = parse(path) + report(entries, path) + return 0 + + +if __name__ == "__main__": + sys.exit(main(sys.argv)) diff --git a/mpisppy/__init__.py b/mpisppy/__init__.py index a2b5cea08..00f045eae 100644 --- a/mpisppy/__init__.py +++ b/mpisppy/__init__.py @@ -22,3 +22,29 @@ def global_toc(msg, cond=_global_rank == 0): return tt_timer.toc(msg, delta=False) if cond else None global_toc("Initializing mpi-sppy") + + +def git_commit_hash(): + """DEBUG (LOR_bug): short hash of the running source checkout. + + Claude and the cluster experiments run on different machines; printing + the commit the experiment is actually running removes confusion about + which version produced a given output. Returns "unknown" outside a git + checkout (e.g. an installed package). Remove with the LOR_bug + instrumentation. 
+ """ + import os + import subprocess + repo_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) + try: + sha = subprocess.check_output( + ["git", "-C", repo_dir, "rev-parse", "--short=12", "HEAD"], + stderr=subprocess.DEVNULL, + ).decode().strip() + dirty = subprocess.check_output( + ["git", "-C", repo_dir, "status", "--porcelain", "--untracked-files=no"], + stderr=subprocess.DEVNULL, + ).decode().strip() + return sha + ("-dirty" if dirty else "") + except Exception: + return "unknown" diff --git a/mpisppy/cylinders/spcommunicator.py b/mpisppy/cylinders/spcommunicator.py index f9569db69..03179c187 100644 --- a/mpisppy/cylinders/spcommunicator.py +++ b/mpisppy/cylinders/spcommunicator.py @@ -87,6 +87,19 @@ def reduce_source_write_ids(source_ids, strict: bool) -> int: return source_ids[0] if len(set(source_ids)) == 1 else -1 return min(source_ids) +# ===== DEBUG (LOR_bug): canary guard appended to every field buffer ===== +# We allocate _CANARY_GUARD_DOUBLES extra doubles immediately after the +# padded window region and fill them with a recognizable sentinel. The +# window/RMA view (full_arr) is still exactly padded_len, so MPI transfers +# are unchanged, but any over-write that runs past the field's end lands in +# the guard (which we check every iteration) instead of silently corrupting +# the adjacent glibc tcache. This localizes the over-write to a specific +# field/rank/iteration at the moment it happens, rather than at a later +# malloc. Remove with the rest of the LOR_bug instrumentation. 
+_CANARY_GUARD_DOUBLES = 8 +_CANARY_VALUE = -123456789.0 + + def communicator_array(data_length: int): """ Allocate an MPI memory region with a padded length (multiple of 8 doubles = 64B), @@ -95,6 +108,7 @@ def communicator_array(data_length: int): Returns: full_arr: padded array (used for SPWindow put/get) + guard: trailing canary region (DEBUG; see _CANARY_VALUE) logical_arr: logical view (data + id), last element is id data_length: number of data entries (excluding id) logical_len: data_length + 1 @@ -104,15 +118,25 @@ def communicator_array(data_length: int): padded_len = padded_len_n_doubles(logical_len) itemsize = np.dtype("d").itemsize - mem = MPI.Alloc_mem(padded_len * itemsize) + backing_len = padded_len + _CANARY_GUARD_DOUBLES + mem = MPI.Alloc_mem(backing_len * itemsize) - full_arr = np.frombuffer(mem, dtype="d", count=padded_len) - full_arr[:] = np.nan + backing = np.frombuffer(mem, dtype="d", count=backing_len) + backing[:] = np.nan + + # Window/RMA view is exactly padded_len (MPI behavior unchanged). + full_arr = backing[:padded_len] + # Canary guard immediately after the window view. + guard = backing[padded_len:] + guard[:] = _CANARY_VALUE logical_arr = full_arr[:logical_len] logical_arr[-1] = 0.0 - return full_arr, logical_arr, data_length, logical_len, padded_len + # mem is returned so the caller can release it deterministically via + # MPI.Free_mem (see FieldArray.free); leaving it to garbage collection + # risks MPI.Free_mem running after MPI_Finalize at interpreter shutdown. 
+ return mem, full_arr, guard, logical_arr, data_length, logical_len, padded_len class FieldArray: @@ -127,13 +151,47 @@ class FieldArray: def __init__(self, length: int): # length is the data length (excluding the id) - (self._full_array, + (self._mem, + self._full_array, + self._guard, self._array, self._data_length, self._logical_len, self._padded_len) = communicator_array(length) self._id = 0 + def canary_ok(self) -> bool: + """DEBUG (LOR_bug): True if the trailing guard region is intact. + + A breach means something wrote past this field's padded window + region -- the over-write that was corrupting the adjacent heap. + """ + if self._guard is None: + return True + return bool(np.all(self._guard == _CANARY_VALUE)) + + def free(self) -> None: + """Release the MPI-allocated backing memory deterministically. + + Copies the logical view onto the regular Python heap (so callers + that read final values after teardown -- e.g. ``spcomm.bound`` after + WheelSpinner.spin() -- still work), drops the RMA-only views, then + returns the MPI-allocated backing to MPI. After this the FieldArray + is read-only and must not be used for further RMA. Relying on + garbage collection instead risks MPI.Free_mem running after + MPI_Finalize at interpreter shutdown, which corrupts the heap. + """ + if self._mem is None: + return + # Detach the logical view from the MPI-backed buffer before freeing + # it; a plain copy keeps post-teardown reads valid without aliasing + # released MPI memory. 
+ self._array = np.array(self._array) + self._full_array = None + self._guard = None + MPI.Free_mem(self._mem) + self._mem = None + def window_array(self) -> np.typing.NDArray: """Full padded array (used for SPWindow get/put).""" return self._full_array @@ -517,6 +575,14 @@ def free_windows(self) -> None: if self.window is None: return + # Release the MPI-allocated field buffers deterministically (before + # the window and before MPI_Finalize) rather than leaving them to + # garbage collection, whose timing relative to finalize is undefined. + for fa in self.receive_buffers.values(): + fa.free() + for fa in self.send_buffers.values(): + fa.free() + self.receive_buffers = {} self.send_buffers = {} self.receive_field_spcomms = {} @@ -560,6 +626,37 @@ def register_receive_fields(self) -> None: if self._flex_ranks and field not in _GLOBAL_OR_SCALAR_FIELDS: self._build_overlap_map(field, strata_rank) + def _report_canary_breach(self, where: str) -> None: + """DEBUG (LOR_bug): print a one-line breach record from any rank.""" + import sys + print( + f"[LOR_bug CANARY BREACH] {where} " + f"cls={type(self).__name__} global_rk={self.global_rank} " + f"cyl_rk={self.cylinder_rank} strata_rk={self.strata_rank}", + flush=True, + ) + sys.stdout.flush() + + def check_buffer_canaries(self, where: str = "") -> list: + """DEBUG (LOR_bug): sweep every send/recv buffer's canary guard. + + Returns a list of breached buffers (empty if clean) and prints a + record for each. Call this every iteration to catch an over-write + past a field buffer at the point it happens -- naming the field and + origin -- instead of at a later, unrelated malloc. 
+ """ + bad = [] + for key, fa in self.receive_buffers.items(): + if not fa.canary_ok(): + fld, origin = self._split_key(key) + bad.append(f"recv field={fld.name} origin={origin}") + for fld, fa in self.send_buffers.items(): + if not fa.canary_ok(): + bad.append(f"send field={fld.name}") + if bad: + self._report_canary_breach(f"{where}: " + "; ".join(bad)) + return bad + def put_send_buffer(self, buf: SendArray, field: Field): """ Put the specified values into the specified locally-owned buffer for the another cylinder to pick up. @@ -569,6 +666,8 @@ def put_send_buffer(self, buf: SendArray, field: Field): """ buf._next_write_id() self.window.put(buf.window_array(), field) + if not buf.canary_ok(): + self._report_canary_breach(f"after put_send_buffer field={field.name}") return def _write_ids_agree(self, new_id: int, synchronize: bool) -> bool: @@ -644,6 +743,9 @@ def get_receive_buffer(self, last_id = buf.id() self.window.get(buf.window_array(), origin, field) # padded view + if not buf.canary_ok(): + self._report_canary_breach( + f"after get_receive_buffer field={field.name} origin={origin}") new_id = int(buf.array()[-1]) # logical view diff --git a/mpisppy/cylinders/spoke.py b/mpisppy/cylinders/spoke.py index c88da8451..26e3e2c53 100644 --- a/mpisppy/cylinders/spoke.py +++ b/mpisppy/cylinders/spoke.py @@ -28,6 +28,10 @@ def got_kill_signal(self): """ Spoke should call this method at least every iteration to see if the Hub terminated """ + # DEBUG (LOR_bug): sweep all field-buffer canaries once per iteration, + # before allreduce_or runs (whose allocations otherwise trip the + # already-corrupted heap). A breach names the over-written field. 
+ self.check_buffer_canaries("got_kill_signal") shutdown_buf = self.receive_buffers[self._make_key(Field.SHUTDOWN, 0)] self.get_receive_buffer(shutdown_buf, Field.SHUTDOWN, 0, synchronize=False) fired = bool(shutdown_buf[0] == 1.0) diff --git a/mpisppy/cylinders/spwindow.py b/mpisppy/cylinders/spwindow.py index 003694262..1bb963840 100644 --- a/mpisppy/cylinders/spwindow.py +++ b/mpisppy/cylinders/spwindow.py @@ -169,8 +169,12 @@ def __init__(self, my_fields: dict, strata_comm: MPI.Comm, field_order=None): def free(self): if self.window is not None: - self.window.Free() + # Drop the numpy view of the window memory BEFORE freeing the + # window. self.buff aliases the window's memory via + # window.tomemory(); freeing the window first would leave buff + # dangling over released memory for the duration of the free. self.buff = None + self.window.Free() self.buffer_layout = None self.buffer_length = 0 self.window = None diff --git a/mpisppy/cylinders/xhatshufflelooper_bounder.py b/mpisppy/cylinders/xhatshufflelooper_bounder.py index b44eced3e..bbddc1171 100644 --- a/mpisppy/cylinders/xhatshufflelooper_bounder.py +++ b/mpisppy/cylinders/xhatshufflelooper_bounder.py @@ -13,6 +13,7 @@ from mpisppy.extensions.xhatbase import XhatBase from mpisppy.cylinders.xhatbase import XhatInnerBoundBase from mpisppy.cylinders._preloop_xhat_mixin import _PreLoopXhatMixin +from mpisppy.debug_utils.heap_probe import heap_probe # DEBUG(LOR_bug) # Could also pass, e.g., sys.stdout instead of a filename @@ -40,6 +41,9 @@ def xhat_prep(self): def try_scenario_dict(self, xhat_scenario_dict): """ wrapper for _try_one""" snamedict = xhat_scenario_dict + # DEBUG(LOR_bug): last marker before the xhat eval / gurobi set_objective + # where both 2026-06-11 runs aborted ("unaligned tcache chunk"). 
+ heap_probe("xhatshuffle:pre-try_scenario_dict", rank=self.global_rank) stage2_ef_solver_name = self.opt.options.get("stage2_ef_solver_name", None) branching_factors = self.opt.options.get("branching_factors", None) # for stage2ef @@ -104,6 +108,8 @@ def _vb(msg): print("(rank0) " + msg) xh_iter = 1 + # DEBUG(LOR_bug): heap intact entering the cylinder loop. + heap_probe("xhatshuffle:main-enter", rank=self.global_rank) while not self.got_kill_signal(): # (unrelated: uncomment the next line to see the source of delay getting an xhat) if (xh_iter-1) % 100 == 0: @@ -111,6 +117,9 @@ def _vb(msg): logger.debug(f' Xhatshuffle got from opt on rank {self.global_rank}') new_nonants = self.update_nonants() + # DEBUG(LOR_bug): bracket the RMA receive of nonants from the hub. + heap_probe(f"xhatshuffle:post-update_nonants:iter{xh_iter}", + rank=self.global_rank) # When there is no iter0, the serial number must be checked. if self._nonant_len_receive_buffer.id() == 0: diff --git a/mpisppy/debug_utils/heap_probe.py b/mpisppy/debug_utils/heap_probe.py new file mode 100644 index 000000000..18659a5ab --- /dev/null +++ b/mpisppy/debug_utils/heap_probe.py @@ -0,0 +1,86 @@ +############################################################################### +# mpi-sppy: MPI-based Stochastic Programming in PYthon +# +# Copyright (c) 2025, Lawrence Livermore National Security, LLC, Alliance for +# Sustainable Energy, LLC, The Regents of the University of California, et al. +# All rights reserved. Please see the files COPYRIGHT.md and LICENSE.md for +# full copyright and license information. +############################################################################### +"""Phase-boundary heap-integrity probe for the LOR_bug investigation (PR #717). + +OFF by default. Set env ``MPISPPY_LOR_HEAP_PROBE=1`` to activate. 
+ +When active, ``heap_probe(label)`` forces the glibc allocator to walk its own +heap metadata: a two-pass malloc/free sweep across the tcache size classes plus +some fastbin/smallbin/largebin sizes, followed by ``malloc_trim(0)``. The first +pass frees chunks into the tcache bins; the second pass pulls them back out -- +exactly the ``tcache_get`` path that raises "malloc(): unaligned tcache chunk +detected". If the heap is already corrupted, glibc aborts INSIDE the probe +(SIGABRT, no Python traceback), pinning the corrupting write to the phase +between the last printed ``[LOR_bug HEAP PROBE OK]`` marker and this call site. + +Even when the probe does not actively abort, the OK markers bracket the +corruption: the last marker before the crash names the surviving phase. + +Numpy and MPI buffers live in the same glibc arena this probe walks, so a stomp +into an adjacent chunk's header is visible here. This is debug-only +instrumentation; remove before merging to main. +""" + +import os +import sys +import ctypes + +_ENABLED = bool(os.environ.get("MPISPPY_LOR_HEAP_PROBE")) + +# Per size class, pull a burst large enough to fully drain the tcache bin +# (glibc default depth is 7) and reach into the fastbins/smallbins behind it, +# so a poisoned entry sitting anywhere in the bin is pulled (and tripped) here. +_BURST = 16 + +_libc = None +_sizes = None + + +def enabled(): + return _ENABLED + + +def _init(): + global _libc, _sizes + if _libc is not None: + return + _libc = ctypes.CDLL("libc.so.6", use_errno=True) + _libc.malloc.restype = ctypes.c_void_p + _libc.malloc.argtypes = [ctypes.c_size_t] + _libc.free.argtypes = [ctypes.c_void_p] + _libc.malloc_trim.argtypes = [ctypes.c_size_t] + _libc.malloc_trim.restype = ctypes.c_int + # tcache bins cover ~16..1032 byte requests in 16-byte steps; append a few + # fastbin/smallbin/largebin sizes so malloc_trim's consolidation walks them. 
+ _sizes = list(range(16, 1040, 16)) + [1536, 2048, 4096, 8192, 65536, 262144] + + +def heap_probe(label, rank=None): + """Walk the glibc heap; abort here if its metadata is already corrupted. + + No-op (just an env check) unless MPISPPY_LOR_HEAP_PROBE is set. + """ + if not _ENABLED: + return + _init() + # Drain every tcache/fastbin size class by pulling a burst from each (the + # tcache_get path that raises "unaligned tcache chunk detected" against any + # poisoned entry), holding them all so the bin actually empties, then free + # them back and trim. A single alloc/free per size would never reach a + # poisoned entry behind the bin head, so the burst depth is essential. + held = [] + for sz in _sizes: + for _ in range(_BURST): + held.append(_libc.malloc(sz)) + for p in held: + _libc.free(p) + _libc.malloc_trim(0) + tag = f" rank={rank}" if rank is not None else "" + print(f"[LOR_bug HEAP PROBE OK] {label}{tag}", flush=True) + sys.stderr.flush() diff --git a/mpisppy/spbase.py b/mpisppy/spbase.py index 27873763e..562875988 100644 --- a/mpisppy/spbase.py +++ b/mpisppy/spbase.py @@ -642,10 +642,139 @@ def spcomm(self, value): def allreduce_or(self, val): - local_val = np.array([val], dtype='int8') - global_val = np.zeros(1, dtype='int8') - self.mpicomm.Allreduce(local_val, global_val, op=MPI.LOR) - if global_val[0] > 0: + # ====== CONTROL toggle (LOR_bug) ====== + # When MPISPPY_LOR_CONTROL is set in the environment, bypass the + # instrumentation below (4 extra Allreduces + 1 Allgather + ~9 numpy + # allocations per call) and run the original minimal path. This lets + # the single PR branch produce a control data point -- to tell whether + # the heap corruption is real or an artifact of the diagnostic's + # collective volume -- without re-pushing. Canary guards and the + # teardown fixes are unaffected (they live elsewhere). See PR #717. 
+ import os + if os.environ.get("MPISPPY_LOR_CONTROL"): + local_val = np.array([val], dtype='int8') + global_val = np.zeros(1, dtype='int8') + self.mpicomm.Allreduce(local_val, global_val, op=MPI.LOR) + return bool(global_val[0] > 0) + # ====== END CONTROL toggle ====== + # ====== DEBUG: LOR_bug instrumentation ====== + # Yields per call (on cyl_rk == 0 of self.mpicomm) the full picture + # needed to localize an Allreduce(LOR) returning nonzero when every + # rank intends a zero. Probes four axes: + # 1. comm membership (world ranks participating, size, uniqueness) + # 2. data going in (Allgather of every rank's local_val) + # 3. reduction sanity (SUM/MAX/LOR + a rank-sum check whose + # expected value is n*(n-1)/2) + # 4. consistency (compare Allgather sum to Allreduce SUM + # to tell "input was wrong" from + # "Allreduce is wrong") + # Remove before merging to main. See PR description for hypothesis tree. + import os + import socket + import sys + sz = self.mpicomm.Get_size() + cyl_rk = self.mpicomm.Get_rank() + world_rk = MPI.COMM_WORLD.Get_rank() + host = socket.gethostname() + pid = os.getpid() + + local_int = 1 if val else 0 + local_int32 = np.array([local_int], dtype='int32') + local_int8 = np.array([local_int], dtype='int8') + + # (3) Reductions — three ops in parallel, plus a rank-sum sanity check. + sum_out = np.zeros(1, dtype='int32') + self.mpicomm.Allreduce(local_int32, sum_out, op=MPI.SUM) + + max_out = np.zeros(1, dtype='int32') + self.mpicomm.Allreduce(local_int32, max_out, op=MPI.MAX) + + lor_out = np.zeros(1, dtype='int8') + self.mpicomm.Allreduce(local_int8, lor_out, op=MPI.LOR) + + rank_in = np.array([cyl_rk], dtype='int32') + rank_out = np.zeros(1, dtype='int32') + self.mpicomm.Allreduce(rank_in, rank_out, op=MPI.SUM) + expected_rank_sum = sz * (sz - 1) // 2 + + # (1) + (2) Allgather of (world_rk, cyl_rk, local_int) so we see + # exactly which ranks participated and what each one contributed. 
+ report = np.array([world_rk, cyl_rk, local_int], dtype='int32') + all_reports = np.zeros(3 * sz, dtype='int32') + self.mpicomm.Allgather(report, all_reports) + + # Track a per-instance call counter so logs are correlatable. + self._lor_diag_count = getattr(self, "_lor_diag_count", 0) + 1 + call_n = self._lor_diag_count + + if cyl_rk == 0: + rows = all_reports.reshape(sz, 3) + wr = rows[:, 0].tolist() + nonzero_rows = rows[rows[:, 2] != 0] + gather_sum = int(rows[:, 2].sum()) + cls = type(self).__name__ + try: + comm_name = self.mpicomm.Get_name() + except Exception: + comm_name = "" + print( + f"[LOR_bug call={call_n} cls={cls} " + f"world_rk={world_rk} host={host} pid={pid}] " + f"mpicomm size={sz} name={comm_name!r}", + flush=True, + ) + print( + f" world_ranks: min={min(wr)} max={max(wr)} " + f"count={len(wr)} unique={len(set(wr))}", + flush=True, + ) + print( + f" reductions: sum={int(sum_out[0])} max={int(max_out[0])} " + f"lor={int(lor_out[0])} rank_sum={int(rank_out[0])} " + f"expected_rank_sum={expected_rank_sum}", + flush=True, + ) + print( + f" gather: gather_sum={gather_sum} " + f"nonzero_reports={len(nonzero_rows)}", + flush=True, + ) + # "Bad" = invariant-violating, NOT just "nonzero result." A + # legitimate shutdown signal returns sum=lor=1 with + # gather_sum=1 (consistent), which is fine. The real bug + # signature is gather_sum disagreeing with the Allreduce SUM + # (the reducer lying), or the rank-sum sanity check failing + # (SUM broken on this comm), or duplicate world ranks + # (group membership corrupted), or some rank packing >1 + # (non-boolean input — only possible under memory aliasing). 
+ bad = ( + int(rank_out[0]) != expected_rank_sum + or int(sum_out[0]) != gather_sum + or len(set(wr)) != len(wr) + or int(max_out[0]) > 1 + ) + if bad: + limit = min(64, len(nonzero_rows)) + for w, c, v in nonzero_rows[:limit].tolist(): + print( + f" nonzero: world_rk={w} cyl_rk={c} local_val={v}", + flush=True, + ) + if len(nonzero_rows) > limit: + print( + f" (... {len(nonzero_rows) - limit} more nonzero rows truncated ...)", + flush=True, + ) + # Also dump the full world-rank list once so we can see exactly + # who is participating in this comm. + print( + f" ALL world_ranks: {wr}", + flush=True, + ) + sys.stdout.flush() + # ====== END DEBUG ====== + + if lor_out[0] > 0: return True else: return False diff --git a/mpisppy/spin_the_wheel.py b/mpisppy/spin_the_wheel.py index 558b07172..57f2fe144 100644 --- a/mpisppy/spin_the_wheel.py +++ b/mpisppy/spin_the_wheel.py @@ -8,11 +8,12 @@ ############################################################################### from pyomo.environ import value -from mpisppy import haveMPI, global_toc, MPI +from mpisppy import haveMPI, global_toc, git_commit_hash, MPI from mpisppy.utils import nice_join from mpisppy.utils.sputils import first_stage_nonant_writer, scenario_tree_solution_writer from mpisppy.utils.rank_apportionment import apportion_ranks, rank_to_cylinder +from mpisppy.debug_utils.heap_probe import heap_probe # DEBUG(LOR_bug) class WheelSpinner: @@ -84,6 +85,12 @@ def run(self, comm_world=None): if comm_world is None: comm_world = MPI.COMM_WORLD + # DEBUG (LOR_bug): announce the running source commit so output from + # the cluster can be matched to a version (Claude and the experiments + # run on different machines). Remove with the LOR_bug instrumentation. 
+ global_toc(f"Running mpi-sppy commit {git_commit_hash()}", + comm_world.Get_rank() == 0) + _key_conversion = { "hub_class" : "spcomm_class", "hub_kwargs" : "spcomm_kwargs", @@ -159,6 +166,8 @@ def run(self, comm_world=None): # Create the appropriate opt object locally opt_kwargs["mpicomm"] = cylinder_comm opt = opt_class(**opt_kwargs) + # DEBUG(LOR_bug): bracket scenario creation + SPFBBT presolve. + heap_probe("after-opt-create", rank=global_rank) # Create the SPCommunicator object (hub/spoke) with # the appropriate SPBase object attached @@ -166,12 +175,18 @@ def run(self, comm_world=None): communicator_list, **sp_kwargs) spcomm.make_windows() + # DEBUG(LOR_bug): bracket RMA window allocation. + heap_probe("after-make-windows", rank=global_rank) # Run main() if strata_rank == 0: spcomm.setup_hub() + # DEBUG(LOR_bug): bracket hub setup (hub rank only). + heap_probe("after-setup-hub", rank=global_rank) global_toc("Starting spcomm.main()", comm_world.rank == 0) + # DEBUG(LOR_bug): last marker before the cylinder loop starts. + heap_probe("pre-main", rank=global_rank) spcomm.main() if strata_rank == 0: # If this is the hub spcomm.send_terminate() @@ -188,6 +203,10 @@ def run(self, comm_world=None): ## give the hub the chance to catch new values spcomm.hub_finalize() + + # hub_finalize may issue RMA gets; ensure every rank is done with + # remote reads before any window is freed. + fullcomm.Barrier() spcomm.free_windows() fullcomm.Barrier() diff --git a/mpisppy/tests/test_buffer_inspect.py b/mpisppy/tests/test_buffer_inspect.py index 39df831e2..78ae9e815 100644 --- a/mpisppy/tests/test_buffer_inspect.py +++ b/mpisppy/tests/test_buffer_inspect.py @@ -426,6 +426,11 @@ def _make_spoke_stub(shutdown_buf, *, inspect_on=True, Spoke._inspect_buffers_on_shutdown, stub) stub._warn_if_buffer_bad = types.MethodType( Spoke._warn_if_buffer_bad, stub) + # got_kill_signal also sweeps field-buffer canaries (LOR_bug debug aid). 
+ stub.check_buffer_canaries = types.MethodType( + Spoke.check_buffer_canaries, stub) + stub._report_canary_breach = types.MethodType( + Spoke._report_canary_breach, stub) return stub