Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
import subprocess
import tempfile
import textwrap
import time
from collections.abc import Callable
from typing import TYPE_CHECKING

Expand Down Expand Up @@ -156,6 +157,7 @@ def run_beam_command(
process_line_callback: Callable[[str], None] | None = None,
working_directory: str | None = None,
is_dataflow_job_id_exist_callback: Callable[[], bool] | None = None,
periodic_callback: Callable[[], None] | None = None,
) -> None:
"""
Run pipeline command in subprocess.
Expand All @@ -165,6 +167,9 @@ def run_beam_command(
stdout and stderr to detect job id
:param working_directory: Working directory
:param log: logger.
:param periodic_callback: Optional callback invoked roughly every 5 seconds while the
subprocess is running. Used by deferrable Dataflow operators to poll for a job ID
when the launcher does not emit a parseable job-ID line to stdout.
"""
log.info("Running command: %s", " ".join(shlex.quote(c) for c in cmd))

Expand All @@ -179,6 +184,7 @@ def run_beam_command(
# Waits for Apache Beam pipeline to complete.
log.info("Start waiting for Apache Beam process to complete.")
reads = [proc.stderr, proc.stdout]
last_periodic_call = time.monotonic()
while True:
# Wait for at least one available fd.
readable_fds, _, _ = select.select(reads, [], [], 5)
Expand All @@ -191,6 +197,13 @@ def run_beam_command(
if is_dataflow_job_id_exist_callback and is_dataflow_job_id_exist_callback():
return

now = time.monotonic()
if periodic_callback and now - last_periodic_call >= 5:
periodic_callback()
last_periodic_call = now
if is_dataflow_job_id_exist_callback and is_dataflow_job_id_exist_callback():
return

if proc.poll() is not None:
break

Expand Down Expand Up @@ -228,6 +241,7 @@ def _start_pipeline(
process_line_callback: Callable[[str], None] | None = None,
working_directory: str | None = None,
is_dataflow_job_id_exist_callback: Callable[[], bool] | None = None,
periodic_callback: Callable[[], None] | None = None,
) -> None:
cmd = [*command_prefix, f"--runner={self.runner}"]
if variables:
Expand All @@ -238,6 +252,7 @@ def _start_pipeline(
working_directory=working_directory,
log=self.log,
is_dataflow_job_id_exist_callback=is_dataflow_job_id_exist_callback,
periodic_callback=periodic_callback,
)

def start_python_pipeline(
Expand All @@ -250,6 +265,7 @@ def start_python_pipeline(
py_system_site_packages: bool = False,
process_line_callback: Callable[[str], None] | None = None,
is_dataflow_job_id_exist_callback: Callable[[], bool] | None = None,
periodic_callback: Callable[[], None] | None = None,
):
"""
Start Apache Beam python pipeline.
Expand Down Expand Up @@ -319,6 +335,7 @@ def start_python_pipeline(
command_prefix=command_prefix,
process_line_callback=process_line_callback,
is_dataflow_job_id_exist_callback=is_dataflow_job_id_exist_callback,
periodic_callback=periodic_callback,
)

def start_java_pipeline(
Expand All @@ -328,6 +345,7 @@ def start_java_pipeline(
job_class: str | None = None,
process_line_callback: Callable[[str], None] | None = None,
is_dataflow_job_id_exist_callback: Callable[[], bool] | None = None,
periodic_callback: Callable[[], None] | None = None,
) -> None:
"""
Start Apache Beam Java pipeline.
Expand All @@ -347,6 +365,7 @@ def start_java_pipeline(
command_prefix=command_prefix,
process_line_callback=process_line_callback,
is_dataflow_job_id_exist_callback=is_dataflow_job_id_exist_callback,
periodic_callback=periodic_callback,
)

def start_go_pipeline(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,32 @@ def is_dataflow_job_id_exist() -> bool:

return is_dataflow_job_id_exist

def __get_dataflow_job_id_poll_callback(
    self,
    dataflow_hook: DataflowHook,
    job_name: str,
    location: str,
    project_id: str,
) -> Callable[[], None]:
    """
    Build a zero-argument callback that looks up the Dataflow job ID through the API.

    The returned callback does nothing once ``self.dataflow_job_id`` is truthy. Otherwise it
    calls ``dataflow_hook.fetch_job_id_by_name`` with the lower-cased job name as prefix and,
    when that returns a truthy value, stores it on ``self.dataflow_job_id``. Any exception
    raised during a lookup is logged at debug level and swallowed so that a later call can
    try again.

    NOTE(review): the leading double underscore makes this name class-private, i.e. it is
    mangled with the name of the class that defines it. A call spelled
    ``self.__get_dataflow_job_id_poll_callback(...)`` inside a *different* class body resolves
    to a different attribute name -- confirm that every call site lives in the defining class.

    :param dataflow_hook: Hook whose ``fetch_job_id_by_name`` performs the lookup.
    :param job_name: Dataflow job name; lower-cased on each lookup and used as the name prefix.
    :param location: Location passed through to the lookup.
    :param project_id: Project ID passed through to the lookup.
    :return: Callback that resolves and records the job ID as a side effect.
    """

    def poll_for_job_id() -> None:
        # An ID that is already known (e.g. parsed from process output) is never overwritten.
        if self.dataflow_job_id:
            return
        try:
            job_id = dataflow_hook.fetch_job_id_by_name(
                prefix_name=job_name.lower(),
                location=location,
                project_id=project_id,
            )
            if not job_id:
                # No matching job yet; stay quiet and let the next invocation retry.
                return
            self.log.info("Resolved Dataflow job ID via API lookup: %s", job_id)
            self.dataflow_job_id = job_id
        except Exception:
            # Best effort: a failed lookup must not interrupt the caller's wait loop.
            self.log.debug("Periodic Dataflow job ID lookup failed; will retry.", exc_info=True)

    return poll_for_job_id


class BeamBasePipelineOperator(BaseOperator, BeamDataflowMixin, ABC):
"""
Expand Down Expand Up @@ -448,6 +474,14 @@ def execute_on_dataflow(self, context: Context):
if not self.dataflow_hook:
self.dataflow_hook = self.__set_dataflow_hook()

location = self.dataflow_config.location or DEFAULT_DATAFLOW_LOCATION
periodic_callback = self.__get_dataflow_job_id_poll_callback(
dataflow_hook=self.dataflow_hook,
job_name=self.dataflow_job_name,
location=location,
project_id=self.dataflow_config.project_id,
)

self.beam_hook.start_python_pipeline(
variables=self.snake_case_pipeline_options,
py_file=self.py_file,
Expand All @@ -457,9 +491,8 @@ def execute_on_dataflow(self, context: Context):
py_system_site_packages=self.py_system_site_packages,
process_line_callback=self.process_line_callback,
is_dataflow_job_id_exist_callback=self.is_dataflow_job_id_exist_callback,
periodic_callback=periodic_callback,
)

location = self.dataflow_config.location or DEFAULT_DATAFLOW_LOCATION
DataflowJobLink.persist(
context=context,
region=self.dataflow_config.location,
Expand Down Expand Up @@ -647,12 +680,19 @@ def execute_on_dataflow(self, context: Context):

if not is_running:
self.pipeline_options["jobName"] = self.dataflow_job_name
periodic_callback = self.__get_dataflow_job_id_poll_callback(
dataflow_hook=self.dataflow_hook,
job_name=self.dataflow_job_name,
location=self.dataflow_config.location or DEFAULT_DATAFLOW_LOCATION,
project_id=self.dataflow_config.project_id,
)
self.beam_hook.start_java_pipeline(
variables=self.pipeline_options,
jar=self.jar,
job_class=self.job_class,
process_line_callback=self.process_line_callback,
is_dataflow_job_id_exist_callback=self.is_dataflow_job_id_exist_callback,
periodic_callback=periodic_callback,
)
if self.dataflow_job_name and self.dataflow_config.location:
DataflowJobLink.persist(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,7 @@ def test_start_python_pipeline(self, mock_check_output, mock_runner):
working_directory=None,
log=ANY,
is_dataflow_job_id_exist_callback=is_dataflow_job_id_exist_callback,
periodic_callback=None,
)

@mock.patch("airflow.providers.apache.beam.hooks.beam.subprocess.check_output", return_value=b"2.35.0")
Expand Down Expand Up @@ -176,6 +177,7 @@ def test_start_python_pipeline_with_custom_interpreter(
working_directory=None,
log=ANY,
is_dataflow_job_id_exist_callback=is_dataflow_job_id_exist_callback,
periodic_callback=None,
)

@pytest.mark.parametrize(
Expand Down Expand Up @@ -226,6 +228,7 @@ def test_start_python_pipeline_with_non_empty_py_requirements_and_without_system
is_dataflow_job_id_exist_callback=is_dataflow_job_id_exist_callback,
working_directory=None,
log=ANY,
periodic_callback=None,
)
mock_virtualenv.assert_called_once_with(
venv_directory=mock.ANY,
Expand Down Expand Up @@ -282,6 +285,7 @@ def test_start_java_pipeline(self, mock_runner):
working_directory=None,
log=ANY,
is_dataflow_job_id_exist_callback=None,
periodic_callback=None,
)

@mock.patch(BEAM_STRING.format("run_beam_command"))
Expand Down Expand Up @@ -311,6 +315,7 @@ def test_start_java_pipeline_with_job_class(self, mock_runner):
working_directory=None,
log=ANY,
is_dataflow_job_id_exist_callback=None,
periodic_callback=None,
)

@mock.patch(BEAM_STRING.format("shutil.which"))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -263,6 +263,7 @@ def test_exec_dataflow_runner(
py_system_site_packages=False,
process_line_callback=mock.ANY,
is_dataflow_job_id_exist_callback=mock.ANY,
periodic_callback=mock.ANY,
)

@mock.patch(BEAM_OPERATOR_PATH.format("DataflowJobLink.persist"))
Expand All @@ -281,6 +282,40 @@ def test_exec_dataflow_runner__no_dataflow_job_name(
op.execute({})
assert op.dataflow_config.job_name == op.task_id

@mock.patch(BEAM_OPERATOR_PATH.format("DataflowJobLink.persist"))
@mock.patch(BEAM_OPERATOR_PATH.format("BeamHook"))
@mock.patch(BEAM_OPERATOR_PATH.format("DataflowHook"))
@mock.patch(BEAM_OPERATOR_PATH.format("GCSHook"))
def test_exec_dataflow_runner_periodic_callback_fetches_job_id(
    self, gcs_hook, dataflow_hook_mock, beam_hook_mock, persist_link_mock, py_options
):
    """
    Check that the periodic callback resolves the job ID through the Dataflow API.

    The mocked ``start_python_pipeline`` only records the ``periodic_callback`` it receives.
    Invoking that callback afterwards must call ``fetch_job_id_by_name`` exactly once and
    store the returned ID on the operator.
    """
    operator = BeamRunPythonPipelineOperator(
        runner="DataflowRunner",
        **self.default_op_kwargs,
    )
    fetch_job_id = dataflow_hook_mock.return_value.fetch_job_id_by_name
    fetch_job_id.return_value = JOB_ID

    # Every callback handed to the mocked pipeline launcher is recorded here.
    received_callbacks: list = []

    def record_periodic_callback(**kwargs):
        received_callbacks.append(kwargs.get("periodic_callback"))

    beam_hook_mock.return_value.start_python_pipeline.side_effect = record_periodic_callback

    operator.execute({})

    periodic_callback = received_callbacks[-1] if received_callbacks else None
    assert periodic_callback is not None, "periodic_callback was not passed to start_python_pipeline"

    # The launcher is mocked, so nothing has provided a job ID up to this point.
    assert operator.dataflow_job_id is None

    periodic_callback()

    assert operator.dataflow_job_id == JOB_ID
    fetch_job_id.assert_called_once()

@mock.patch(BEAM_OPERATOR_PATH.format("DataflowJobLink.persist"))
@mock.patch(BEAM_OPERATOR_PATH.format("BeamHook"))
@mock.patch(BEAM_OPERATOR_PATH.format("GCSHook"))
Expand Down Expand Up @@ -486,6 +521,7 @@ def test_exec_dataflow_runner(self, gcs_hook, dataflow_hook_mock, beam_hook_mock
job_class=JOB_CLASS,
process_line_callback=mock.ANY,
is_dataflow_job_id_exist_callback=mock.ANY,
periodic_callback=mock.ANY,
)

@mock.patch(BEAM_OPERATOR_PATH.format("DataflowJobLink.persist"))
Expand All @@ -504,6 +540,41 @@ def test_exec_dataflow_runner__no_dataflow_job_name(
op.execute({})
assert op.dataflow_config.job_name == op.task_id

@mock.patch(BEAM_OPERATOR_PATH.format("DataflowJobLink.persist"))
@mock.patch(BEAM_OPERATOR_PATH.format("BeamHook"))
@mock.patch(BEAM_OPERATOR_PATH.format("DataflowHook"))
@mock.patch(BEAM_OPERATOR_PATH.format("GCSHook"))
def test_exec_dataflow_runner_periodic_callback_fetches_job_id(
    self, gcs_hook, dataflow_hook_mock, beam_hook_mock, persist_link_mock
):
    """
    Check that the periodic callback resolves the job ID through the Dataflow API.

    ``is_job_dataflow_running`` is mocked to return False so the operator launches the
    pipeline. The mocked ``start_java_pipeline`` only records the ``periodic_callback`` it
    receives; invoking that callback afterwards must call ``fetch_job_id_by_name`` exactly
    once and store the returned ID on the operator.
    """
    config = DataflowConfiguration(impersonation_chain="test@impersonation.com")
    operator = BeamRunJavaPipelineOperator(
        **self.default_op_kwargs, dataflow_config=config, runner="DataflowRunner"
    )
    hook_instance = dataflow_hook_mock.return_value
    hook_instance.is_job_dataflow_running.return_value = False
    hook_instance.fetch_job_id_by_name.return_value = JOB_ID

    # Every callback handed to the mocked pipeline launcher is recorded here.
    received_callbacks: list = []

    def record_periodic_callback(**kwargs):
        received_callbacks.append(kwargs.get("periodic_callback"))

    beam_hook_mock.return_value.start_java_pipeline.side_effect = record_periodic_callback

    operator.execute({})

    periodic_callback = received_callbacks[-1] if received_callbacks else None
    assert periodic_callback is not None, "periodic_callback was not passed to start_java_pipeline"

    # The launcher is mocked, so nothing has provided a job ID up to this point.
    assert operator.dataflow_job_id is None

    periodic_callback()

    assert operator.dataflow_job_id == JOB_ID
    hook_instance.fetch_job_id_by_name.assert_called_once()

@mock.patch(BEAM_OPERATOR_PATH.format("DataflowJobLink.persist"))
@mock.patch(BEAM_OPERATOR_PATH.format("BeamHook"))
@mock.patch(BEAM_OPERATOR_PATH.format("GCSHook"))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1299,6 +1299,29 @@ def is_job_done(self, location: str, project_id: str, job_id: str) -> bool:
job = job_controller.fetch_job_by_id(job_id)
return job_controller.job_reached_terminal_state(job)

@GoogleBaseHook.fallback_to_default_project_id
def fetch_job_id_by_name(
    self,
    prefix_name: str,
    location: str,
    project_id: str = PROVIDE_PROJECT_ID,
) -> str | None:
    """
    Look up a Dataflow job by name prefix and return its ID.

    NOTE(review): when several jobs share the prefix, the ID of the first entry returned by
    the jobs controller is used; which job comes first is decided by that controller and is
    not visible here -- confirm this is the intended job.

    :param prefix_name: Job name prefix to look up; lower-cased before the lookup.
    :param location: Location of the Dataflow job.
    :param project_id: Google Cloud project ID in which to look up the job.
    :return: The ``id`` of the first matching job, or None when the lookup returns no jobs.
    """
    controller = _DataflowJobsController(
        dataflow=self.get_conn(),
        project_number=project_id,
        location=location,
    )
    matching_jobs = controller._fetch_jobs_by_prefix_name(prefix_name.lower())
    if not matching_jobs:
        return None
    return matching_jobs[0]["id"]

@GoogleBaseHook.fallback_to_default_project_id
def create_data_pipeline(
self,
Expand Down
Loading