diff --git a/providers/apache/beam/src/airflow/providers/apache/beam/hooks/beam.py b/providers/apache/beam/src/airflow/providers/apache/beam/hooks/beam.py index 9dd36bd25a7b6..393b2077ff498 100644 --- a/providers/apache/beam/src/airflow/providers/apache/beam/hooks/beam.py +++ b/providers/apache/beam/src/airflow/providers/apache/beam/hooks/beam.py @@ -29,6 +29,7 @@ import subprocess import tempfile import textwrap +import time from collections.abc import Callable from typing import TYPE_CHECKING @@ -156,6 +157,7 @@ def run_beam_command( process_line_callback: Callable[[str], None] | None = None, working_directory: str | None = None, is_dataflow_job_id_exist_callback: Callable[[], bool] | None = None, + periodic_callback: Callable[[], None] | None = None, ) -> None: """ Run pipeline command in subprocess. @@ -165,6 +167,9 @@ def run_beam_command( stdout and stderr to detect job id :param working_directory: Working directory :param log: logger. + :param periodic_callback: Optional callback invoked roughly every 5 seconds while the + subprocess is running. Used by deferrable Dataflow operators to poll for a job ID + when the launcher does not emit a parseable job-ID line to stdout. """ log.info("Running command: %s", " ".join(shlex.quote(c) for c in cmd)) @@ -179,6 +184,7 @@ def run_beam_command( # Waits for Apache Beam pipeline to complete. log.info("Start waiting for Apache Beam process to complete.") reads = [proc.stderr, proc.stdout] + last_periodic_call = time.monotonic() while True: # Wait for at least one available fd. readable_fds, _, _ = select.select(reads, [], [], 5) @@ -191,6 +197,13 @@ def run_beam_command( if is_dataflow_job_id_exist_callback and is_dataflow_job_id_exist_callback(): return + now = time.monotonic() + if periodic_callback and now - last_periodic_call >= 5: + periodic_callback() + last_periodic_call = now + if is_dataflow_job_id_exist_callback and is_dataflow_job_id_exist_callback(): + return + if proc.poll() is not None: break @@ -228,6 +241,7 @@ def _start_pipeline( process_line_callback: Callable[[str], None] | None = None, working_directory: str | None = None, is_dataflow_job_id_exist_callback: Callable[[], bool] | None = None, + periodic_callback: Callable[[], None] | None = None, ) -> None: cmd = [*command_prefix, f"--runner={self.runner}"] if variables: @@ -238,6 +252,7 @@ def _start_pipeline( working_directory=working_directory, log=self.log, is_dataflow_job_id_exist_callback=is_dataflow_job_id_exist_callback, + periodic_callback=periodic_callback, ) def start_python_pipeline( @@ -250,6 +265,7 @@ def start_python_pipeline( py_system_site_packages: bool = False, process_line_callback: Callable[[str], None] | None = None, is_dataflow_job_id_exist_callback: Callable[[], bool] | None = None, + periodic_callback: Callable[[], None] | None = None, ): """ Start Apache Beam python pipeline. @@ -319,6 +335,7 @@ def start_python_pipeline( command_prefix=command_prefix, process_line_callback=process_line_callback, is_dataflow_job_id_exist_callback=is_dataflow_job_id_exist_callback, + periodic_callback=periodic_callback, ) def start_java_pipeline( @@ -328,6 +345,7 @@ def start_java_pipeline( job_class: str | None = None, process_line_callback: Callable[[str], None] | None = None, is_dataflow_job_id_exist_callback: Callable[[], bool] | None = None, + periodic_callback: Callable[[], None] | None = None, ) -> None: """ Start Apache Beam Java pipeline. @@ -347,6 +365,7 @@ def start_java_pipeline( command_prefix=command_prefix, process_line_callback=process_line_callback, is_dataflow_job_id_exist_callback=is_dataflow_job_id_exist_callback, + periodic_callback=periodic_callback, ) def start_go_pipeline( diff --git a/providers/apache/beam/src/airflow/providers/apache/beam/operators/beam.py b/providers/apache/beam/src/airflow/providers/apache/beam/operators/beam.py index e42f1e1d60a39..91c5b96bef72a 100644 --- a/providers/apache/beam/src/airflow/providers/apache/beam/operators/beam.py +++ b/providers/apache/beam/src/airflow/providers/apache/beam/operators/beam.py @@ -158,6 +158,32 @@ def is_dataflow_job_id_exist() -> bool: return is_dataflow_job_id_exist + def __get_dataflow_job_id_poll_callback( + self, + dataflow_hook: DataflowHook, + job_name: str, + location: str, + project_id: str, + ) -> Callable[[], None]: + """Return a callback that polls Dataflow API for the job ID when stdout hasn't provided it yet.""" + + def poll() -> None: + if self.dataflow_job_id: + return + try: + resolved = dataflow_hook.fetch_job_id_by_name( + prefix_name=job_name.lower(), + location=location, + project_id=project_id, + ) + if resolved: + self.log.info("Resolved Dataflow job ID via API lookup: %s", resolved) + self.dataflow_job_id = resolved + except Exception: + self.log.debug("Periodic Dataflow job ID lookup failed; will retry.", exc_info=True) + + return poll + class BeamBasePipelineOperator(BaseOperator, BeamDataflowMixin, ABC): """ @@ -448,6 +474,14 @@ def execute_on_dataflow(self, context: Context): if not self.dataflow_hook: self.dataflow_hook = self.__set_dataflow_hook() + location = self.dataflow_config.location or DEFAULT_DATAFLOW_LOCATION + periodic_callback = self.__get_dataflow_job_id_poll_callback( + dataflow_hook=self.dataflow_hook, + job_name=self.dataflow_job_name, + location=location, + project_id=self.dataflow_config.project_id, + ) + self.beam_hook.start_python_pipeline( variables=self.snake_case_pipeline_options, py_file=self.py_file, @@ -457,9 +491,8 @@ def execute_on_dataflow(self, context: Context): py_system_site_packages=self.py_system_site_packages, process_line_callback=self.process_line_callback, is_dataflow_job_id_exist_callback=self.is_dataflow_job_id_exist_callback, + periodic_callback=periodic_callback, ) - - location = self.dataflow_config.location or DEFAULT_DATAFLOW_LOCATION DataflowJobLink.persist( context=context, region=self.dataflow_config.location, @@ -647,12 +680,19 @@ def execute_on_dataflow(self, context: Context): if not is_running: self.pipeline_options["jobName"] = self.dataflow_job_name + periodic_callback = self.__get_dataflow_job_id_poll_callback( + dataflow_hook=self.dataflow_hook, + job_name=self.dataflow_job_name, + location=self.dataflow_config.location or DEFAULT_DATAFLOW_LOCATION, + project_id=self.dataflow_config.project_id, + ) self.beam_hook.start_java_pipeline( variables=self.pipeline_options, jar=self.jar, job_class=self.job_class, process_line_callback=self.process_line_callback, is_dataflow_job_id_exist_callback=self.is_dataflow_job_id_exist_callback, + periodic_callback=periodic_callback, ) if self.dataflow_job_name and self.dataflow_config.location: DataflowJobLink.persist( diff --git a/providers/apache/beam/tests/unit/apache/beam/hooks/test_beam.py b/providers/apache/beam/tests/unit/apache/beam/hooks/test_beam.py index e9750170280c0..4e55624a55e5f 100644 --- a/providers/apache/beam/tests/unit/apache/beam/hooks/test_beam.py +++ b/providers/apache/beam/tests/unit/apache/beam/hooks/test_beam.py @@ -112,6 +112,7 @@ def test_start_python_pipeline(self, mock_check_output, mock_runner): working_directory=None, log=ANY, is_dataflow_job_id_exist_callback=is_dataflow_job_id_exist_callback, + periodic_callback=None, ) @mock.patch("airflow.providers.apache.beam.hooks.beam.subprocess.check_output", return_value=b"2.35.0") @@ -176,6 +177,7 @@ def test_start_python_pipeline_with_custom_interpreter( working_directory=None, log=ANY, is_dataflow_job_id_exist_callback=is_dataflow_job_id_exist_callback, + periodic_callback=None, ) @pytest.mark.parametrize( @@ -226,6 +228,7 @@ def test_start_python_pipeline_with_non_empty_py_requirements_and_without_system is_dataflow_job_id_exist_callback=is_dataflow_job_id_exist_callback, working_directory=None, log=ANY, + periodic_callback=None, ) mock_virtualenv.assert_called_once_with( venv_directory=mock.ANY, @@ -282,6 +285,7 @@ def test_start_java_pipeline(self, mock_runner): working_directory=None, log=ANY, is_dataflow_job_id_exist_callback=None, + periodic_callback=None, ) @mock.patch(BEAM_STRING.format("run_beam_command")) @@ -311,6 +315,7 @@ def test_start_java_pipeline_with_job_class(self, mock_runner): working_directory=None, log=ANY, is_dataflow_job_id_exist_callback=None, + periodic_callback=None, ) @mock.patch(BEAM_STRING.format("shutil.which")) diff --git a/providers/apache/beam/tests/unit/apache/beam/operators/test_beam.py b/providers/apache/beam/tests/unit/apache/beam/operators/test_beam.py index 9c21fe8d2becb..086a0c4c92f70 100644 --- a/providers/apache/beam/tests/unit/apache/beam/operators/test_beam.py +++ b/providers/apache/beam/tests/unit/apache/beam/operators/test_beam.py @@ -263,6 +263,7 @@ def test_exec_dataflow_runner( py_system_site_packages=False, process_line_callback=mock.ANY, is_dataflow_job_id_exist_callback=mock.ANY, + periodic_callback=mock.ANY, ) @mock.patch(BEAM_OPERATOR_PATH.format("DataflowJobLink.persist")) @@ -281,6 +282,40 @@ def test_exec_dataflow_runner__no_dataflow_job_name( op.execute({}) assert op.dataflow_config.job_name == op.task_id + @mock.patch(BEAM_OPERATOR_PATH.format("DataflowJobLink.persist")) + @mock.patch(BEAM_OPERATOR_PATH.format("BeamHook")) + @mock.patch(BEAM_OPERATOR_PATH.format("DataflowHook")) + @mock.patch(BEAM_OPERATOR_PATH.format("GCSHook")) + def test_exec_dataflow_runner_periodic_callback_fetches_job_id( + self, gcs_hook, dataflow_hook_mock, beam_hook_mock, persist_link_mock, py_options + ): + """When launcher stdout does not emit a job-ID line, the periodic_callback must poll + the Dataflow API via fetch_job_id_by_name and set dataflow_job_id on the operator.""" + op = BeamRunPythonPipelineOperator( + runner="DataflowRunner", + **self.default_op_kwargs, + ) + dataflow_hook_mock.return_value.fetch_job_id_by_name.return_value = JOB_ID + + captured: dict = {} + + def capture_start(**kwargs): + captured["periodic_callback"] = kwargs.get("periodic_callback") + + beam_hook_mock.return_value.start_python_pipeline.side_effect = capture_start + + op.execute({}) + + periodic_callback = captured.get("periodic_callback") + assert periodic_callback is not None, "periodic_callback was not passed to start_python_pipeline" + + assert op.dataflow_job_id is None + + periodic_callback() + + assert op.dataflow_job_id == JOB_ID + dataflow_hook_mock.return_value.fetch_job_id_by_name.assert_called_once() + @mock.patch(BEAM_OPERATOR_PATH.format("DataflowJobLink.persist")) @mock.patch(BEAM_OPERATOR_PATH.format("BeamHook")) @mock.patch(BEAM_OPERATOR_PATH.format("GCSHook")) @@ -486,6 +521,7 @@ def test_exec_dataflow_runner(self, gcs_hook, dataflow_hook_mock, beam_hook_mock job_class=JOB_CLASS, process_line_callback=mock.ANY, is_dataflow_job_id_exist_callback=mock.ANY, + periodic_callback=mock.ANY, ) @mock.patch(BEAM_OPERATOR_PATH.format("DataflowJobLink.persist")) @@ -504,6 +540,41 @@ def test_exec_dataflow_runner__no_dataflow_job_name( op.execute({}) assert op.dataflow_config.job_name == op.task_id + @mock.patch(BEAM_OPERATOR_PATH.format("DataflowJobLink.persist")) + @mock.patch(BEAM_OPERATOR_PATH.format("BeamHook")) + @mock.patch(BEAM_OPERATOR_PATH.format("DataflowHook")) + @mock.patch(BEAM_OPERATOR_PATH.format("GCSHook")) + def test_exec_dataflow_runner_periodic_callback_fetches_job_id( + self, gcs_hook, dataflow_hook_mock, beam_hook_mock, persist_link_mock + ): + """When launcher stdout does not emit a job-ID line, the periodic_callback must poll + the Dataflow API via fetch_job_id_by_name and set dataflow_job_id on the operator.""" + dataflow_config = DataflowConfiguration(impersonation_chain="test@impersonation.com") + op = BeamRunJavaPipelineOperator( + **self.default_op_kwargs, dataflow_config=dataflow_config, runner="DataflowRunner" + ) + dataflow_hook_mock.return_value.is_job_dataflow_running.return_value = False + dataflow_hook_mock.return_value.fetch_job_id_by_name.return_value = JOB_ID + + captured: dict = {} + + def capture_start(**kwargs): + captured["periodic_callback"] = kwargs.get("periodic_callback") + + beam_hook_mock.return_value.start_java_pipeline.side_effect = capture_start + + op.execute({}) + + periodic_callback = captured.get("periodic_callback") + assert periodic_callback is not None, "periodic_callback was not passed to start_java_pipeline" + + assert op.dataflow_job_id is None + + periodic_callback() + + assert op.dataflow_job_id == JOB_ID + dataflow_hook_mock.return_value.fetch_job_id_by_name.assert_called_once() + @mock.patch(BEAM_OPERATOR_PATH.format("DataflowJobLink.persist")) @mock.patch(BEAM_OPERATOR_PATH.format("BeamHook")) @mock.patch(BEAM_OPERATOR_PATH.format("GCSHook")) diff --git a/providers/google/src/airflow/providers/google/cloud/hooks/dataflow.py b/providers/google/src/airflow/providers/google/cloud/hooks/dataflow.py index ea1af01c1ad50..72c6b6c1982c3 100644 --- a/providers/google/src/airflow/providers/google/cloud/hooks/dataflow.py +++ b/providers/google/src/airflow/providers/google/cloud/hooks/dataflow.py @@ -1299,6 +1299,29 @@ def is_job_done(self, location: str, project_id: str, job_id: str) -> bool: job = job_controller.fetch_job_by_id(job_id) return job_controller.job_reached_terminal_state(job) + @GoogleBaseHook.fallback_to_default_project_id + def fetch_job_id_by_name( + self, + prefix_name: str, + location: str, + project_id: str = PROVIDE_PROJECT_ID, + ) -> str | None: + """ + Fetch the ID of the Dataflow job whose name matches the given prefix. + + :param prefix_name: Job name prefix to look up. + :param location: Location of the Dataflow job. + :param project_id: Google Cloud project ID in which to look up the job. + :return: Job ID if a matching job is found, else None. + """ + jobs_controller = _DataflowJobsController( + dataflow=self.get_conn(), + project_number=project_id, + location=location, + ) + jobs = jobs_controller._fetch_jobs_by_prefix_name(prefix_name.lower()) + return jobs[0]["id"] if jobs else None + @GoogleBaseHook.fallback_to_default_project_id def create_data_pipeline( self,