diff --git a/.gitignore b/.gitignore index d80203e..cce7ef6 100644 --- a/.gitignore +++ b/.gitignore @@ -108,5 +108,9 @@ ENV/ .vscode/ .idea/ +# vim +*.swp +*.swo + /data/ /preprocess_toolbox/scratches/ diff --git a/preprocess_toolbox/base.py b/preprocess_toolbox/base.py index 4c74da3..2f2417d 100644 --- a/preprocess_toolbox/base.py +++ b/preprocess_toolbox/base.py @@ -144,7 +144,7 @@ def get_dataset(self, for var_filepaths in self.processed_files[vn].values()] logging.info("Got {} filenames to open dataset with!".format(len(var_files))) - logging.debug(pformat(var_files)) + # logging.debug(pformat(var_files)) # TODO: where's my parallel mfdataset please!? with (dask.config.set(**{'array.slicing.split_large_chunks': True})): diff --git a/preprocess_toolbox/cli.py b/preprocess_toolbox/cli.py index 95ca6d3..d278a1b 100644 --- a/preprocess_toolbox/cli.py +++ b/preprocess_toolbox/cli.py @@ -16,13 +16,14 @@ class ProcessingArgParser(BaseArgParser): def __init__(self, *args, + base_path="processed_data", **kwargs): super().__init__(*args, **kwargs) self.add_argument("source", type=str) self.add_argument("-p", "--destination-path", help="Folder that any output data collections will be put in", - type=str, default="processed_data") + type=str, default=base_path) def add_ref_ds(self): self.add_argument("reference", type=str) diff --git a/preprocess_toolbox/dataset/cli.py b/preprocess_toolbox/dataset/cli.py index c54bac0..056f478 100644 --- a/preprocess_toolbox/dataset/cli.py +++ b/preprocess_toolbox/dataset/cli.py @@ -14,7 +14,7 @@ def process_dataset(): - args = (ProcessingArgParser(). + args = (ProcessingArgParser(base_path="processed"). add_concurrency(). add_destination(). add_implementation(). @@ -36,6 +36,7 @@ def process_dataset(): anom_clim_splits=args.processing_splits, config_path=args.config, identifier=args.destination_id, + base_path=args.destination_path, # TODO: nomenclature is old here, lag and lead make sense in forecasting, but not in here # so this mapping should be revised throughout the library - we don't necessarily forecast! lag_time=args.split_head, diff --git a/preprocess_toolbox/dataset/process.py b/preprocess_toolbox/dataset/process.py index 2cda6ad..9f52582 100644 --- a/preprocess_toolbox/dataset/process.py +++ b/preprocess_toolbox/dataset/process.py @@ -150,6 +150,9 @@ def rotate_dataset(ref_file: os.PathLike, wind_cubes[vars_to_rotate[1]], angles, ) + if len(wind_cubes_r[vars_to_rotate[0]].shape) == 2 and len(wind_cubes_r[vars_to_rotate[1]].shape) == 2: + wind_cubes_r[vars_to_rotate[0]] = iris.util.new_axis(wind_cubes_r[vars_to_rotate[0]], "time") + wind_cubes_r[vars_to_rotate[1]] = iris.util.new_axis(wind_cubes_r[vars_to_rotate[1]], "time") except iris.exceptions.CoordinateNotFoundError: logging.exception("Failure to rotate due to coordinate issues. " "moving onto next file") diff --git a/preprocess_toolbox/interface.py b/preprocess_toolbox/interface.py index cddc5e3..b2552b1 100644 --- a/preprocess_toolbox/interface.py +++ b/preprocess_toolbox/interface.py @@ -34,7 +34,6 @@ def get_processor_implementation(config: os.PathLike) -> object: create_kwargs = dict(**remaining) logging.info("Attempting to instantiate {} with loaded configuration".format(implementation)) - logging.debug("Converted kwargs from the retrieved configuration: {}".format(create_kwargs)) return implementation(**create_kwargs) @@ -57,7 +56,6 @@ def get_processor_from_source(identifier: str, source_cfg: dict) -> object: create_kwargs = {k: v for k, v in source_cfg.items() if k not in ["dataset_config", "implementation"]} logging.info("Attempting to instantiate {} with loaded configuration".format(source_cfg["implementation"])) - logging.debug("Converted kwargs from the retrieved configuration: {}".format(create_kwargs)) return get_implementation(source_cfg["implementation"])( get_dataset_config_implementation(source_cfg["dataset_config"]), diff --git a/preprocess_toolbox/loader/cli.py b/preprocess_toolbox/loader/cli.py index cbeb61b..3822606 100644 --- a/preprocess_toolbox/loader/cli.py +++ b/preprocess_toolbox/loader/cli.py @@ -45,12 +45,15 @@ def add_sections(self): class MetaArgParser(LoaderArgParser): - def __init__(self): + def __init__(self, base_path="processed_data"): super().__init__() self.add_argument("ground_truth_dataset") self.add_argument("-p", "--destination-path", help="Folder that any output data collections will be put in", - type=str, default="processed_data") + type=str, default=base_path) + self.add_argument("-l", "--loader-path", + help="Path to the loader JSON config file to load", + type=str, default=None) def add_channel(self): self.add_argument("channel_name") @@ -74,6 +77,9 @@ def create(): channels=dict(), ) destination_path = get_config_filename(args) + destination_directory = os.path.dirname(destination_path) + if destination_directory: + os.makedirs(destination_directory, exist_ok=True) if not os.path.exists(destination_path): with open(destination_path, "w") as fh: @@ -129,22 +135,31 @@ def add_processed(): def get_channel_info_from_processor(cfg_segment: str): - args = (MetaArgParser(). + args, unknown_args = (MetaArgParser(base_path="processed"). add_channel(). - parse_args()) + parse_known_args()) proc_impl = get_implementation(args.implementation) ds_config = get_dataset_config_implementation(args.ground_truth_dataset) if args.config is not None: - # FIXME: args.config contains the location of the dataset config on render, but - # this is not part of this pattern! DS is either ground truth or in derived class, - # but this library doesn't care or know of it respectively. - raise RuntimeError("--config-path is invalid for this CLI endpoint, sorry...") - - processor = proc_impl(ds_config, - [args.channel_name,], - args.channel_name) + # FIXME: args.config contains the location of the dataset config on render, but + # this is not part of this pattern! DS is either ground truth or in derived class, + # but this library doesn't care or know of it respectively. + raise RuntimeError("--config-path is invalid for this CLI endpoint, sorry...") + + impl_args = ( + ds_config, + [ + args.channel_name, + ], + args.channel_name, + ) + impl_kwargs = {"base_path": args.destination_path} + if unknown_args: + impl_kwargs |= unknown_args + + processor = proc_impl(*impl_args, **impl_kwargs) processor.process() update_config(get_config_filename(args), cfg_segment, diff --git a/preprocess_toolbox/processor.py b/preprocess_toolbox/processor.py index 4aa9778..92f8792 100644 --- a/preprocess_toolbox/processor.py +++ b/preprocess_toolbox/processor.py @@ -95,8 +95,16 @@ def __init__(self, self._normalisation_splits = [] if normalisation_splits is None else normalisation_splits self._parallel = parallel_opens self._refdir = ref_procdir + + ## + # Split dates - + # + # TODO: splits -> { dates, sources }, but currently sources are separate... self._splits = splits + self._dropped_split_dates = {} + # TODO: add self._dropped_dates based on DATA + self._source_files = dict() if init_source: @@ -223,45 +231,49 @@ def _init_source_data(self, :return: """ - split_dates_required = dict() drop_dates = dict() + all_dates = dict() for split in self._splits.keys(): - dates = sorted(self._splits[split]) + all_dates[split] = sorted(self._splits[split]) drop_dates[split] = list() - if dates: + if all_dates[split]: logging.info("Processing {} dates for {} category: {} - {}". - format(len(dates), split, min(dates), max(dates))) + format(len(all_dates[split]), split, min(all_dates[split]), max(all_dates[split]))) else: logging.info("No {} dates for this processor".format(split)) continue # Calculating lead and lag dates that aren't already accounted for in splits - if self._lag_time > 0: + if self._lag_time >= 0: logging.info("Including lag of {} {}s".format(self._lag_time, ds_config.frequency.attribute)) - additional_lag_dates, dropped_lag_dates = get_extension_dates(ds_config, dates, self._lag_time, reverse=True) - dates += additional_lag_dates + additional_lag_dates, dropped_lag_dates = get_extension_dates( + ds_config, all_dates[split], + # We offset by two, because -1 is channel one, so we need to account for lag == 1 being -2 + self._lag_time + 2, + start_step=1, reverse=True) + all_dates[split] += additional_lag_dates drop_dates[split] += dropped_lag_dates logging.info("Lag added {} dates for {} category: {} - {}". - format(len(dates), split, min(dates), max(dates))) + format(len(all_dates[split]), split, min(all_dates[split]), max(all_dates[split]))) if self._lead_time > 0: logging.info("Including lead of {} {}s".format(self._lead_time, ds_config.frequency.attribute)) - additional_lead_dates, dropped_lead_dates = get_extension_dates(ds_config, dates, self._lead_time) - dates += additional_lead_dates + additional_lead_dates, dropped_lead_dates = get_extension_dates(ds_config, all_dates[split], self._lead_time) + all_dates[split] += additional_lead_dates drop_dates[split] += dropped_lead_dates logging.info("Lead added {} dates for {} category: {} - {}". - format(len(dates), split, min(dates), max(dates))) + format(len(all_dates[split]), split, min(all_dates[split]), max(all_dates[split]))) - split_dates_required[split] = sorted([_ for _ in dates if _ not in drop_dates[split]]) + self._dropped_split_dates[split] = sorted(drop_dates[split]) + all_dates[split] = sorted([_ for _ in all_dates[split] if _ not in drop_dates[split]]) for split in self._splits.keys(): - self._source_files[split] = {var_config.name: ds_config.var_filepaths(var_config, split_dates_required[split]) + self._source_files[split] = {var_config.name: ds_config.var_filepaths(var_config, all_dates[split]) for var_config in ds_config.variables} for var_name, var_files in self._source_files[split].items(): logging.info("Got {} files for {}:{}".format(len(var_files), split, var_name)) - logging.debug(pformat(self._source_files)) def _normalise_array_mean(self, var_name: str, da: object, denormalise: bool=False): """ @@ -341,7 +353,6 @@ def _normalise_array_scaling(self, var_name: str, da: object, denormalise: bool= elif self.norm_split_dates: logging.debug("Generating norm-scaling min-max from {} training " "dates".format(len(self.norm_split_dates))) - norm_samples = da.sel(time=self.norm_split_dates).data norm_samples = norm_samples.ravel() @@ -375,7 +386,8 @@ def _process_channel(self, for split, var_files in self.source_files.items() for vn, files in var_files.items() for file in files - if var_name == vn]))) + if var_name == vn + and os.path.exists(file)]))) if len(source_files) > 0: logging.info("Opening {} files for {}".format(len(source_files), var_name)) @@ -384,22 +396,16 @@ def _process_channel(self, # data so this was harder. Now we work with whatever we get from download-toolbox ds = xr.open_mfdataset( source_files, - # Solves issue with inheriting files without - # time dimension (only having coordinate) - combine="nested", - concat_dim="time", - coords="minimal", - compat="override", - # TODO: review this, but if lat-lon is in the file, it's signalling bigger issues - # drop_variables=("lat", "lon"), - parallel=self._parallel) + engine="h5netcdf", + parallel=self._parallel, + lock=False) da = getattr(ds, var_name) da = da.astype(self.dtype) + logging.debug("Files to be opened: {}".format(da.dims)) # FIXME: we should ideally store train dates against the # normalisation and climatology, to ensure recalculation on # reprocess. All this need be is in the path, to be honest - if var_suffix == "anom": if len(self._anom_clim_splits) < 1 and self._refdir is None: raise ProcessingError("You must provide a list of splits via " @@ -497,6 +503,7 @@ def get_config(self, **kwargs): return { "implementation": "{}:{}".format(self.__module__, self.__class__.__name__), + "base_path": self._base_path, "anomoly_vars": self._anom_vars, "absolute_vars": self.abs_vars, "dataset_config": self._dataset_config, @@ -507,7 +514,9 @@ def get_config(self, **kwargs): "path": self.path, "processed_files": self._processed_files, "source_files": self._source_files, - "splits": self.splits, + "splits": {split: [ + date for date in dates if date not in self._dropped_split_dates[split] + ] for split, dates in self._splits.items()}, } @staticmethod @@ -588,7 +597,6 @@ def lead_time(self) -> int: @property def norm_split_dates(self): - # TODO: functools.cached_property, though slightly odd behaviour re. write-ability return [date for clim_split in self._normalisation_splits for date in self._splits[clim_split]] diff --git a/preprocess_toolbox/utils.py b/preprocess_toolbox/utils.py index 5ecc4dc..6507e84 100644 --- a/preprocess_toolbox/utils.py +++ b/preprocess_toolbox/utils.py @@ -6,8 +6,10 @@ from dateutil.relativedelta import relativedelta import orjson +import pandas as pd +import xarray as xr -from download_toolbox.interface import DatasetConfig +from download_toolbox.interface import DatasetConfig, Frequency def get_config(config_path: os.PathLike): @@ -23,6 +25,13 @@ def get_config_filename(args: argparse.Namespace, prefix: str = "loader"): if prefix is not None: default_loader_config = "{}.{}".format(prefix, default_loader_config) + if ( + "loader_path" in args + and args.loader_path is not None + and (os.path.isfile(args.loader_path) or not os.path.exists(args.loader_path)) + ): + return args.loader_path + # TODO: this is a bit grim, but to allow different config output paths it's very flexible. refactor if args.config is not None and (os.path.isfile(args.config) or not os.path.exists(args.config)): logging.warning("{} has been specified, overriding default name {}".format(args.config, args.name)) @@ -35,25 +44,43 @@ def get_config_filename(args: argparse.Namespace, prefix: str = "loader"): def get_extension_dates(ds_config: DatasetConfig, dates: list, num_steps: int, - reverse=False): + start_step: int = 0, + reverse: bool = False): additional_dates, dropped_dates = [], [] for date in dates: - for time in range(num_steps): - attrs = {"{}s".format(ds_config.frequency.attribute): time + 1} + for time in range(start_step, num_steps): + attrs = {"{}s".format(ds_config.frequency.attribute): time} op = operator.sub if reverse else operator.add extended_date = op(date, relativedelta(**attrs)) - if extended_date not in dates: - if all([os.path.exists(ds_config.var_filepath(var_config, [extended_date])) - for var_config in ds_config.variables]): - # We only add these dates into the mix if all necessary files exist - additional_dates.append(extended_date) + if ds_config.frequency == Frequency.MONTH: + extended_date = pd.to_datetime(extended_date + pd.offsets.MonthEnd(0)).date() + + # Check we don't know we have data, and also ignore previous occurrences + if extended_date not in dates and extended_date not in additional_dates: + extended_date_var_files = [ds_config.var_filepath(var_config, [extended_date]) + for var_config in ds_config.variables] + if all([os.path.exists(df) for df in extended_date_var_files]): + # The above will catch those items that fall outside the file output boundary, but not missing + # dates within ALL files. This next clause is more expensive, but necessary to catch everything! + logging.debug("Files exist, double checking whether {} appears in data itself across {} files". + format(extended_date, len(extended_date_var_files))) + + # TODO: this won't catch partially available dates where not all files have the date, but some do + if pd.Timestamp(extended_date) in xr.open_mfdataset(extended_date_var_files, compat="no_conflicts").time.values: + # We only add these dates into the mix if all necessary files exist + additional_dates.append(extended_date) + else: + logging.warning("Nope, {} not in data itself so dropping {}".format(extended_date, date)) + dropped_dates.append(date) + break else: # Otherwise, warn that the lag data means this is being dropped logging.warning("{} will be dropped due to missing data {}". format(date, extended_date)) dropped_dates.append(date) + break return sorted(list(set(additional_dates))), sorted(list(set(dropped_dates)))