diff --git a/src/transformez/grid_engine.py b/src/transformez/grid_engine.py index 4baadf6..93a9a59 100644 --- a/src/transformez/grid_engine.py +++ b/src/transformez/grid_engine.py @@ -190,6 +190,10 @@ def smart_blend(in_grid, background_grid, blend_pixels=50): dist = ndimage.distance_transform_edt(mask) alpha = np.clip(dist / blend_pixels, 0.0, 1.0) + # --- Hermite Interpolation --- + # This converts the linear gradient into a smooth S-curve + alpha = alpha * alpha * (3.0 - 2.0 * alpha) + nearest_indices = ndimage.distance_transform_edt( mask, return_distances=False, return_indices=True ) @@ -210,26 +214,13 @@ def coastal_aware_composite( shapefiles=None, decay_pixels=100, buffer_pixels=10, - max_discontinuity=0.5, + blend_pixels=50, ): """Handles inland decay vs. offshore blending, while filtering out low-resolution global artifacts. """ final_grid = vdatum_grid.copy() - vdatum_mask = np.isnan(vdatum_grid) - if not vdatum_mask.all(): - nearest_idx = ndimage.distance_transform_edt( - vdatum_mask, return_distances=False, return_indices=True - ) - nearest_vdatum_vals = vdatum_grid[tuple(nearest_idx)] - - fes_anomaly_mask = ( - np.abs(global_grid - nearest_vdatum_vals) > max_discontinuity - ) - - global_grid[fes_anomaly_mask] = np.nan - land_mask = None if shapefiles: land_mask = GridEngine.create_land_mask(region, nx, ny, shapefiles) @@ -244,14 +235,13 @@ def coastal_aware_composite( if is_offshore.any(): blended_ocean = GridEngine.smart_blend( - vdatum_grid, global_grid, blend_pixels=50 + vdatum_grid, global_grid, blend_pixels=blend_pixels ) final_grid[is_offshore] = blended_ocean[is_offshore] if is_inland.any(): - source_for_decay = vdatum_grid if is_vdatum.any() else final_grid decayed_inland = GridEngine.fill_nans( - source_for_decay, + final_grid, decay_pixels=decay_pixels, buffer_pixels=buffer_pixels, land_mask=land_mask, diff --git a/src/transformez/transform.py b/src/transformez/transform.py index 9031f8a..73df39c 100644 --- a/src/transformez/transform.py +++ b/src/transformez/transform.py @@ -563,7 +563,6 @@ def _get_vdatum_chain(self, datum_name, geoid_name): shapefiles=coast_shapefiles, decay_pixels=self.decay_pixels, buffer_pixels=10, - max_discontinuity=5, ) desc.append(f"Blended w/ Global({proxy_name.upper()})") else: