-
Notifications
You must be signed in to change notification settings - Fork 55
Support datetime expressions in weather-sp output templates #536
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -13,9 +13,10 @@ | |
| # limitations under the License. | ||
|
|
||
| from dataclasses import dataclass | ||
| import ast | ||
| import datetime as _datetime | ||
| import logging | ||
| import os | ||
| import string | ||
| import typing as t | ||
|
|
||
| logger = logging.getLogger(__name__) | ||
|
|
@@ -24,6 +25,148 @@ | |
| NETCDF_FILE_ENDINGS = ('.nc', '.cd') | ||
|
|
||
|
|
||
| class _DateTimeNamespace: | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. We can simplify the codebase by removing this class and its helper methods, since the AST-based approach can be dropped entirely. |
||
| date = _datetime.date | ||
| datetime = _datetime.datetime | ||
| time = _datetime.time | ||
| timedelta = _datetime.timedelta | ||
| timezone = _datetime.timezone | ||
| strptime = staticmethod(_datetime.datetime.strptime) | ||
|
|
||
|
|
||
| _TEMPLATE_GLOBALS = { | ||
| '__builtins__': {'__import__': __import__}, | ||
| 'datetime': _DateTimeNamespace, | ||
| 'float': float, | ||
| 'int': int, | ||
| 'str': str, | ||
| } | ||
|
|
||
|
|
||
| def _iter_template_fields(template: str) -> t.Iterator[t.Tuple[str, str]]: | ||
| """Yield literal text and replacement fields from a format template.""" | ||
| literal = [] | ||
| i = 0 | ||
| while i < len(template): | ||
| char = template[i] | ||
| if char == '{': | ||
| if i + 1 < len(template) and template[i + 1] == '{': | ||
| literal.append('{') | ||
| i += 2 | ||
| continue | ||
| if literal: | ||
| yield ''.join(literal), '' | ||
| literal = [] | ||
| field, i = _read_template_field(template, i + 1) | ||
| yield '', field | ||
| elif char == '}': | ||
| if i + 1 < len(template) and template[i + 1] == '}': | ||
| literal.append('}') | ||
| i += 2 | ||
| continue | ||
| raise ValueError("single '}' encountered in format string") | ||
| else: | ||
| literal.append(char) | ||
| i += 1 | ||
| if literal: | ||
| yield ''.join(literal), '' | ||
|
|
||
|
|
||
| def _read_template_field(template: str, start: int) -> t.Tuple[str, int]: | ||
| field = [] | ||
| quote = '' | ||
| escaped = False | ||
| depth = 0 | ||
| i = start | ||
| while i < len(template): | ||
| char = template[i] | ||
| if quote: | ||
| field.append(char) | ||
| if escaped: | ||
| escaped = False | ||
| elif char == '\\': | ||
| escaped = True | ||
| elif char == quote: | ||
| quote = '' | ||
| i += 1 | ||
| continue | ||
| if char in ('"', "'"): | ||
| quote = char | ||
| field.append(char) | ||
| elif char == '{': | ||
| depth += 1 | ||
| field.append(char) | ||
| elif char == '}': | ||
| if depth == 0: | ||
| return ''.join(field), i + 1 | ||
| depth -= 1 | ||
| field.append(char) | ||
| else: | ||
| field.append(char) | ||
| i += 1 | ||
| raise ValueError("expected '}' before end of string") | ||
|
|
||
|
|
||
| def _split_field(field: str) -> t.Tuple[str, str, str]: | ||
| quote = '' | ||
| escaped = False | ||
| depth = 0 | ||
| conversion_at = None | ||
| format_at = None | ||
| for i, char in enumerate(field): | ||
| if quote: | ||
| if escaped: | ||
| escaped = False | ||
| elif char == '\\': | ||
| escaped = True | ||
| elif char == quote: | ||
| quote = '' | ||
| continue | ||
| if char in ('"', "'"): | ||
| quote = char | ||
| elif char in '([{': | ||
| depth += 1 | ||
| elif char in ')]}': | ||
| depth -= 1 | ||
| elif depth == 0 and char == '!' and conversion_at is None and format_at is None: | ||
| conversion_at = i | ||
| elif depth == 0 and char == ':' and format_at is None: | ||
| format_at = i | ||
| break | ||
|
|
||
| field_name_end = min(x for x in (conversion_at, format_at, len(field)) if x is not None) | ||
| field_name = field[:field_name_end] | ||
| conversion = '' | ||
| if conversion_at is not None: | ||
| conversion_end = format_at if format_at is not None else len(field) | ||
| conversion = field[conversion_at + 1:conversion_end] | ||
| format_spec = field[format_at + 1:] if format_at is not None else '' | ||
| return field_name, conversion, format_spec | ||
|
|
||
|
|
||
| def _names_in_expression(expression: str) -> t.List[str]: | ||
| if expression.isdigit(): | ||
| return [] | ||
| if expression.isidentifier(): | ||
| return [expression] | ||
|
|
||
| tree = ast.parse(expression, mode='eval') | ||
| names = [] | ||
| for node in ast.walk(tree): | ||
| if isinstance(node, ast.Name) and node.id not in _TEMPLATE_GLOBALS: | ||
| names.append(node.id) | ||
| return names | ||
|
|
||
|
|
||
| def _validate_expression(expression: str) -> None: | ||
| tree = ast.parse(expression, mode='eval') | ||
| for node in ast.walk(tree): | ||
| if isinstance(node, ast.Name) and node.id.startswith('__'): | ||
| raise ValueError(f'Disallowed name in output template: {node.id!r}') | ||
| if isinstance(node, ast.Attribute) and node.attr.startswith('__'): | ||
| raise ValueError(f'Disallowed attribute in output template: {node.attr!r}') | ||
|
|
||
|
|
||
| @dataclass | ||
| class OutFileInfo: | ||
| """Holds data required to construct an output file name. | ||
|
|
@@ -50,13 +193,41 @@ def unformatted_output_path(self): | |
| return self.file_name_template + self.formatting + self.ending | ||
|
|
||
| def split_dims(self) -> t.List[str]: | ||
| all_format = list(filter(None, [field[1] for field in string.Formatter().parse( | ||
| self.unformatted_output_path())])) | ||
| return [key for key in all_format if not key.isdigit()] | ||
| dims = [] | ||
| for _, field in _iter_template_fields(self.unformatted_output_path()): | ||
| if not field: | ||
| continue | ||
| field_name, _, _ = _split_field(field) | ||
| for name in _names_in_expression(field_name): | ||
| if name not in dims: | ||
| dims.append(name) | ||
| return dims | ||
|
|
||
| def formatted_output_path(self, splits: t.Dict[str, str]) -> str: | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Please update `formatted_output_path` to use native Python string formatting. |
||
| """Construct output file name with formatting applied""" | ||
| return self.unformatted_output_path().format(*self.template_folders, **splits) | ||
| output = [] | ||
| for literal, field in _iter_template_fields(self.unformatted_output_path()): | ||
| output.append(literal) | ||
| if not field: | ||
| continue | ||
| field_name, conversion, format_spec = _split_field(field) | ||
| if field_name.isdigit(): | ||
| value = self.template_folders[int(field_name)] | ||
| elif field_name.isidentifier(): | ||
| value = splits[field_name] | ||
| else: | ||
| _validate_expression(field_name) | ||
| value = eval(field_name, _TEMPLATE_GLOBALS, dict(splits)) | ||
| if conversion == 's': | ||
| value = str(value) | ||
| elif conversion == 'r': | ||
| value = repr(value) | ||
| elif conversion == 'a': | ||
| value = ascii(value) | ||
| elif conversion: | ||
| raise ValueError(f'Unknown conversion specifier {conversion!r}') | ||
| output.append(format(value, format_spec)) | ||
| return ''.join(output) | ||
|
|
||
|
|
||
| def get_output_file_info(filename: str, | ||
|
|
||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. With this approach, a pipeline started with an output template that uses double quotes, or no quotes at all, can fail. We should handle these scenarios as well. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Instead of using a complex AST-based parser, can we implement a simple native wrapper for date and time that preserves the raw GRIB date/time field formats by default?