Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
181 changes: 176 additions & 5 deletions weather_sp/splitter_pipeline/file_name_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,10 @@
# limitations under the License.

from dataclasses import dataclass
import ast

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Instead of using a complex AST-based parser, can we implement a simple native wrapper for date and time that preserves the raw GRIB date/time field formats by default?

import datetime as _datetime
import logging
import os
import string
import typing as t

logger = logging.getLogger(__name__)
Expand All @@ -24,6 +25,148 @@
NETCDF_FILE_ENDINGS = ('.nc', '.cd')


class _DateTimeNamespace:

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can simplify the codebase by removing this class and its helper functions, since they are only needed by the AST-based approach, which we can drop.

# Each attribute re-exposes the object of the same name from the standard
# ``datetime`` module, which this file imports as ``_datetime``.
date = _datetime.date
datetime = _datetime.datetime
time = _datetime.time
timedelta = _datetime.timedelta
timezone = _datetime.timezone
# ``strptime`` is additionally exposed directly on the namespace, wrapped in
# ``staticmethod`` so that looking it up through the class returns the plain
# callable of ``_datetime.datetime.strptime``.
strptime = staticmethod(_datetime.datetime.strptime)


# Global namespace handed to ``eval`` when an output-template field is an
# expression rather than a plain name or index. Names listed here are also
# excluded when collecting the split dimensions used by an expression.
# ``__builtins__`` is replaced by a dict that holds only ``__import__``, so the
# usual built-in functions are not reachable by name from a template.
# NOTE(review): presumably ``__import__`` is kept because
# ``datetime.strptime`` imports a helper module lazily -- confirm before
# removing it.
_TEMPLATE_GLOBALS = {
'__builtins__': {'__import__': __import__},
# ``datetime`` resolves to the _DateTimeNamespace class, not the module.
'datetime': _DateTimeNamespace,
'float': float,
'int': int,
'str': str,
}


def _iter_template_fields(template: str) -> t.Iterator[t.Tuple[str, str]]:
    """Split a format template into literal runs and replacement fields.

    Yields ``(literal, field)`` pairs in template order. Literal text is
    yielded as ``(text, '')`` and a replacement field as ``('', field_text)``,
    where ``field_text`` is the raw text found between the field's braces.
    Doubled braces (``{{`` and ``}}``) are unescaped into the literal text.

    Raises:
        ValueError: if a lone ``}`` appears outside a replacement field, or if
            reading a replacement field fails.
    """
    pending = []  # Characters of the literal run currently being collected.
    pos = 0
    end = len(template)
    while pos < end:
        # An escaped brace contributes a single brace to the literal text.
        if template.startswith(('{{', '}}'), pos):
            pending.append(template[pos])
            pos += 2
            continue

        current = template[pos]
        if current == '}':
            raise ValueError("single '}' encountered in format string")
        if current != '{':
            pending.append(current)
            pos += 1
            continue

        # A replacement field starts here: flush the literal collected so far,
        # then let the field reader consume everything up to the closing brace.
        if pending:
            yield ''.join(pending), ''
            pending = []
        field_text, pos = _read_template_field(template, pos + 1)
        yield '', field_text

    if pending:
        yield ''.join(pending), ''


def _read_template_field(template: str, start: int) -> t.Tuple[str, int]:
    """Read the body of one replacement field.

    Scanning begins at ``start``, the index just after the field's opening
    ``{``. Braces inside single- or double-quoted strings are ignored, with
    backslash escapes honoured inside the quotes, and nested ``{...}`` pairs
    are kept as part of the field text.

    Returns:
        A ``(field_text, next_index)`` tuple, where ``next_index`` is the
        position just past the field's closing ``}``.

    Raises:
        ValueError: if the template ends before the field is closed.
    """
    collected = []
    open_quote = None   # Quote character of the string literal we are inside.
    skip_next = False   # True right after a backslash inside a string literal.
    nesting = 0         # Number of unclosed '{' seen inside the field.

    for pos in range(start, len(template)):
        current = template[pos]

        if open_quote is not None:
            # Inside a string literal everything is copied verbatim.
            collected.append(current)
            if skip_next:
                skip_next = False
            elif current == '\\':
                skip_next = True
            elif current == open_quote:
                open_quote = None
            continue

        if current == '}':
            if nesting == 0:
                return ''.join(collected), pos + 1
            nesting -= 1
        elif current == '{':
            nesting += 1
        elif current in ('"', "'"):
            open_quote = current
        collected.append(current)

    raise ValueError("expected '}' before end of string")


def _split_field(field: str) -> t.Tuple[str, str, str]:
    """Split a replacement field into expression, conversion and format spec.

    The field text has the shape ``expression[!conversion][:format_spec]``.
    Only a ``!`` or ``:`` that sits outside string literals and outside any
    ``()``, ``[]`` or ``{}`` grouping is treated as a separator. A ``!`` that
    is immediately followed by ``=`` is the ``!=`` comparison operator and is
    left as part of the expression, matching how f-strings treat it.

    Args:
        field: Raw text between the braces of a replacement field.

    Returns:
        A ``(field_name, conversion, format_spec)`` tuple. ``conversion`` and
        ``format_spec`` are empty strings when absent.
    """
    quote = ''
    escaped = False
    depth = 0
    conversion_at = None
    format_at = None
    for i, char in enumerate(field):
        if quote:
            # Inside a string literal: only look for its (unescaped) end.
            if escaped:
                escaped = False
            elif char == '\\':
                escaped = True
            elif char == quote:
                quote = ''
            continue
        if char in ('"', "'"):
            quote = char
        elif char in '([{':
            depth += 1
        elif char in ')]}':
            depth -= 1
        elif depth != 0:
            # Separators only count at the top level of the expression.
            continue
        elif char == '!' and conversion_at is None:
            # '!=' is a comparison, not the start of a conversion. The slice
            # is safe even when '!' is the last character of the field.
            if field[i + 1:i + 2] != '=':
                conversion_at = i
        elif char == ':':
            # Everything after the first top-level ':' is the format spec.
            format_at = i
            break

    # The expression ends at whichever separator comes first, if any.
    if conversion_at is not None:
        field_name_end = conversion_at
    elif format_at is not None:
        field_name_end = format_at
    else:
        field_name_end = len(field)
    field_name = field[:field_name_end]

    conversion = ''
    if conversion_at is not None:
        conversion_end = format_at if format_at is not None else len(field)
        conversion = field[conversion_at + 1:conversion_end]
    format_spec = field[format_at + 1:] if format_at is not None else ''
    return field_name, conversion, format_spec


def _names_in_expression(expression: str) -> t.List[str]:
    """List the variable names that a template field expression refers to.

    A purely numeric field yields no names, and a bare identifier yields
    itself. Anything else is parsed as a Python expression, and every name it
    references is returned in ``ast.walk`` order, except names that are keys
    of ``_TEMPLATE_GLOBALS``. Duplicates are not removed.

    Raises:
        SyntaxError: if the text is not a valid Python expression.
    """
    if expression.isdigit():
        return []
    if expression.isidentifier():
        return [expression]

    parsed = ast.parse(expression, mode='eval')
    return [
        node.id
        for node in ast.walk(parsed)
        if isinstance(node, ast.Name) and node.id not in _TEMPLATE_GLOBALS
    ]


def _validate_expression(expression: str) -> None:
    """Reject template expressions that mention double-underscore names.

    The expression is parsed (not executed) and every node is inspected. A
    variable name or an attribute name that starts with ``__`` causes the
    expression to be rejected.

    NOTE(review): only identifiers and attribute names that appear literally
    in the expression are inspected; text inside string constants is not
    examined. Confirm this is sufficient wherever the expression is later
    evaluated.

    Raises:
        ValueError: if a name or attribute starting with ``__`` is found.
        SyntaxError: if the text is not a valid Python expression.
    """
    for node in ast.walk(ast.parse(expression, mode='eval')):
        if isinstance(node, ast.Name):
            if node.id.startswith('__'):
                raise ValueError(f'Disallowed name in output template: {node.id!r}')
        elif isinstance(node, ast.Attribute):
            if node.attr.startswith('__'):
                raise ValueError(f'Disallowed attribute in output template: {node.attr!r}')


@dataclass
class OutFileInfo:
"""Holds data required to construct an output file name.
Expand All @@ -50,13 +193,41 @@ def unformatted_output_path(self):
return self.file_name_template + self.formatting + self.ending

def split_dims(self) -> t.List[str]:
    """Return the split dimension names referenced by the output template.

    Every replacement field of the unformatted output path is examined. The
    expression part of the field (without conversion or format spec) is
    scanned for the names it uses, so a field such as
    ``{datetime.strptime(f"{time}", "%H%M")}`` contributes ``time``.
    Purely numeric fields contribute nothing.

    Returns:
        The distinct names, in order of first appearance in the template.
    """
    dims = []
    for _, field in _iter_template_fields(self.unformatted_output_path()):
        if not field:
            # Literal text between fields carries no dimension.
            continue
        field_name, _, _ = _split_field(field)
        for name in _names_in_expression(field_name):
            # Keep first-seen order while dropping repeats.
            if name not in dims:
                dims.append(name)
    return dims

def formatted_output_path(self, splits: t.Dict[str, str]) -> str:

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please update `formatted_output_path` to use native Python string formatting.

"""Construct output file name with formatting applied"""
return self.unformatted_output_path().format(*self.template_folders, **splits)
output = []
for literal, field in _iter_template_fields(self.unformatted_output_path()):
output.append(literal)
if not field:
continue
field_name, conversion, format_spec = _split_field(field)
if field_name.isdigit():
value = self.template_folders[int(field_name)]
elif field_name.isidentifier():
value = splits[field_name]
else:
_validate_expression(field_name)
value = eval(field_name, _TEMPLATE_GLOBALS, dict(splits))
if conversion == 's':
value = str(value)
elif conversion == 'r':
value = repr(value)
elif conversion == 'a':
value = ascii(value)
elif conversion:
raise ValueError(f'Unknown conversion specifier {conversion!r}')
output.append(format(value, format_spec))
return ''.join(output)


def get_output_file_info(filename: str,
Expand Down
26 changes: 26 additions & 0 deletions weather_sp/splitter_pipeline/file_name_utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,12 @@
from .file_name_utils import get_output_file_info, OutFileInfo


# Output template that mixes a plain field (``{date}``) with an expression
# field. The expression parses the ``time`` split as ``%H%M`` and renders it
# as ``%H:%M:%S``.
DATETIME_OUT_PATTERN = (
'gs://test-output/splits/test-file_{date}-'
'{datetime.strptime(f"{time}", "%H%M").strftime("%H:%M:%S")}.grib'
)


class FileNameUtilsTest(unittest.TestCase):

def test_get_output_file_info_pattern(self):
Expand Down Expand Up @@ -97,3 +103,23 @@ def test_split_dims(self):
out_dir=None,
input_base_dir='ignored')
self.assertEqual(actual.split_dims(), ['variable'])

def test_formatted_output_path_supports_datetime_expression(self):
    """An expression field can reformat the raw time split via datetime."""
    info = get_output_file_info(
        filename='gs://test-input/test-file.grib',
        out_pattern=DATETIME_OUT_PATTERN,
        out_dir=None,
        input_base_dir='ignored')
    splits = {'date': '20200101', 'time': '1200'}
    expected = 'gs://test-output/splits/test-file_20200101-12:00:00.grib'

    self.assertEqual(info.formatted_output_path(splits), expected)

def test_split_dims_includes_names_used_by_expression(self):
    """Names used inside an expression field are reported as split dims."""
    info = get_output_file_info(
        filename='gs://test-input/test-file.grib',
        out_pattern=DATETIME_OUT_PATTERN,
        out_dir=None,
        input_base_dir='ignored')
    expected_dims = ['date', 'time']

    self.assertEqual(info.split_dims(), expected_dims)
26 changes: 8 additions & 18 deletions weather_sp/splitter_pipeline/file_splitters.py

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

With this approach, a pipeline that is started with an output template wrapped in double quotes, or not quoted at all, can fail. We should handle both of these cases as well.

Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,7 @@
import itertools
import logging
import os
import re
import shutil
import string
import subprocess
import tempfile
import typing as t
Expand Down Expand Up @@ -196,10 +194,6 @@ class GribSplitterV2(GribSplitter):
See https://confluence.ecmwf.int/display/ECC/grib_copy.
"""

def replace_non_numeric_bracket(self, match: re.Match) -> str:
    """Rewrite one matched ``{name}`` placeholder.

    The first capture group of ``match`` is the text inside the braces. A
    purely numeric value is returned unchanged as ``{value}``; any other
    value is returned in square brackets as ``[value]``.
    """
    token = match.group(1)
    if token.isdigit():
        return "{" + token + "}"
    return f"[{token}]"

def split_data(self) -> None:
if not self.output_info.split_dims():
raise ValueError('No splitting specified in template.')
Expand All @@ -212,17 +206,9 @@ def split_data(self) -> None:
raise EnvironmentError(f'binary {name!r} is not available in the current environment!')

unformatted_output_path = self.output_info.unformatted_output_path()
prefix, _ = os.path.split(next(iter(string.Formatter().parse(unformatted_output_path)))[0])
_, tail = unformatted_output_path.split(prefix)

# Replace { with [ and } with ] only for non-numeric values inside {} of tail
output_str = re.sub(r'\{(\w+)\}', self.replace_non_numeric_bracket, tail)
output_template = output_str.format(*self.output_info.template_folders)

slash = '/'
delimiter = 'DELIMITER'
flat_output_template = output_template.replace('/', delimiter)
split_dims = self.output_info.split_dims()
flat_output_template = delimiter.join(f'[{dim}]' for dim in split_dims)
# Construct a string where each split dimension is "dim:s".
# This ensures dims like time are represented as 0600 instead of 600.
split_dims_arg = ','.join(f'{dim}:s' for dim in split_dims)
Expand Down Expand Up @@ -262,10 +248,14 @@ def split_data(self) -> None:
check=True)

self.logger.info('Uploading %r...', self.input_path)
output_paths_set = set(output_paths)
for flat_target in os.listdir(tmpdir):
dest_file_path = f'{prefix}{flat_target.replace(delimiter, slash)}'
self.logger.info([prefix, dest_file_path, local_file.name,
self.output_info.unformatted_output_path()])
splits = dict(zip(split_dims, flat_target.split(delimiter)))
dest_file_path = self.output_info.formatted_output_path(splits)
if dest_file_path not in output_paths_set:
continue
self.logger.info([dest_file_path, local_file.name,
unformatted_output_path])

copy(os.path.join(tmpdir, flat_target), dest_file_path)
self.logger.info('Finished uploading %r', self.input_path)
Expand Down
Loading