diff --git a/airflow-e2e-tests/tests/airflow_e2e_tests/java_sdk_tests/test_java_sdk_dag.py b/airflow-e2e-tests/tests/airflow_e2e_tests/java_sdk_tests/test_java_sdk_dag.py index 709987cb2c744..6a0da13c479ad 100644 --- a/airflow-e2e-tests/tests/airflow_e2e_tests/java_sdk_tests/test_java_sdk_dag.py +++ b/airflow-e2e-tests/tests/airflow_e2e_tests/java_sdk_tests/test_java_sdk_dag.py @@ -140,6 +140,31 @@ def test_transform_xcom_is_numeric_timestamp(self): f"Expected 'transform' XCom to be a positive integer (millisecond timestamp), got {value!r}" ) + def test_concurrent_client_calls_succeed(self): + """A Java task calling the client from many threads must succeed.""" + resp = self.airflow_client.trigger_dag( + "java_annotation_example", + json={"logical_date": datetime.now(timezone.utc).isoformat()}, + ) + run_id = resp["dag_run_id"] + + dag_state = self.airflow_client.wait_for_dag_run( + dag_id="java_annotation_example", + run_id=run_id, + timeout=_JAVA_TASK_TIMEOUT, + ) + + ti_resp = self.airflow_client.get_task_instances(dag_id="java_annotation_example", run_id=run_id) + ti_map = {ti["task_id"]: ti for ti in ti_resp.get("task_instances", [])} + concurrent_ti = ti_map.get("concurrent", {}) + + assert concurrent_ti.get("state") == "success", ( + f"Java 'concurrent' task did not succeed.\n" + f" task state : {concurrent_ti.get('state')!r}\n" + f" dag state : {dag_state!r}\n" + f" all tasks : { {k: v.get('state') for k, v in ti_map.items()} }" + ) + def test_load_retried_then_succeeded(self): """``load`` fails once (UP_FOR_RETRY) then succeeds on the second attempt. diff --git a/java-sdk/example/src/java/org/apache/airflow/example/AnnotationExample.java b/java-sdk/example/src/java/org/apache/airflow/example/AnnotationExample.java index 85c4162f8caeb..bb715a73cb502 100644 --- a/java-sdk/example/src/java/org/apache/airflow/example/AnnotationExample.java +++ b/java-sdk/example/src/java/org/apache/airflow/example/AnnotationExample.java @@ -21,7 +21,10 @@ import static java.lang.System.Logger.Level.INFO; +import java.util.ArrayList; import java.util.Date; +import java.util.concurrent.Callable; +import java.util.concurrent.Executors; import org.apache.airflow.sdk.*; @SuppressWarnings("DuplicatedCode") @@ -72,4 +75,25 @@ public void load(Context context, @Builder.XCom(task = "transform") long transfo } log.log(INFO, "Recovered on retry, try number {0}", context.ti.tryNumber); } + + // Verify one supervisor channel can handle client calls across threads. + @Builder.Task(id = "concurrent") + public void concurrentClientCalls(Client client) throws Exception { + var pool = Executors.newFixedThreadPool(8); + try { + var calls = new ArrayList>(); + for (var i = 0; i < 32; i++) { + calls.add(() -> client.getConnection("test_http").host); + } + for (var future : pool.invokeAll(calls)) { + var host = future.get(); + if (!"example.com".equals(host)) { + throw new RuntimeException("concurrent getConnection returned wrong host: " + host); + } + } + log.log(INFO, "All concurrent client calls returned the correct connection"); + } finally { + pool.shutdown(); + } + } } diff --git a/java-sdk/example/src/resources/dags/java_examples.py b/java-sdk/example/src/resources/dags/java_examples.py index e3d217e8eaa66..38be0e9e947b2 100644 --- a/java-sdk/example/src/resources/dags/java_examples.py +++ b/java-sdk/example/src/resources/dags/java_examples.py @@ -41,6 +41,10 @@ def transform(): ... def load(): ... +@task.stub(queue="java") +def concurrent(): ... + + @task() def python_task_2(transformed): print("python_task_2") @@ -61,6 +65,7 @@ def java_annotation_example(): python_task_1() >> extract() >> transformed python_task_2(transformed) transformed >> load() + concurrent() java_interface_example() diff --git a/java-sdk/sdk/src/main/kotlin/org/apache/airflow/sdk/execution/Comm.kt b/java-sdk/sdk/src/main/kotlin/org/apache/airflow/sdk/execution/Comm.kt index 5d917af3291c8..2263f6770add9 100644 --- a/java-sdk/sdk/src/main/kotlin/org/apache/airflow/sdk/execution/Comm.kt +++ b/java-sdk/sdk/src/main/kotlin/org/apache/airflow/sdk/execution/Comm.kt @@ -23,6 +23,8 @@ import io.ktor.utils.io.ByteReadChannel import io.ktor.utils.io.ByteWriteChannel import io.ktor.utils.io.readByteArray import io.ktor.utils.io.writeByteArray +import kotlinx.coroutines.sync.Mutex +import kotlinx.coroutines.sync.withLock import org.apache.airflow.sdk.ApiError import org.apache.airflow.sdk.Bundle import org.apache.airflow.sdk.execution.comm.ErrorResponse @@ -56,6 +58,7 @@ class CoordinatorComm( private val nextId = AtomicInt(0) private var shutDownRequested = false + private val commMutex = Mutex() suspend fun startProcessing() { while (!shutDownRequested) { @@ -109,19 +112,22 @@ class CoordinatorComm( } @Throws(ApiError::class) - suspend fun communicateImpl(body: Any): Any { - var frame: IncomingFrame? = null + suspend fun communicateImpl(body: Any): Any = + commMutex.withLock { + val requestId = nextId.fetchAndAdd(1) + var frame: IncomingFrame? = null - suspend fun handle(f: IncomingFrame) { - frame = f - } - sendMessage(nextId.fetchAndAdd(1), body) - processOnce(::handle) - if (frame == null) { - throw ApiError("No response received") + suspend fun handle(f: IncomingFrame) { + frame = f + } + sendMessage(requestId, body) + processOnce(::handle) + val received = frame ?: throw ApiError("No response received") + if (received.id != requestId) { + throw ApiError("response id ${received.id} does not match request id $requestId") + } + received.body ?: Unit } - return frame.body ?: Unit - } @Throws(ApiError::class) suspend inline fun communicate(request: Any): T { diff --git a/java-sdk/sdk/src/test/kotlin/org/apache/airflow/sdk/execution/CommTest.kt b/java-sdk/sdk/src/test/kotlin/org/apache/airflow/sdk/execution/CommTest.kt index 30a5db40a3389..ead1d1dbe518f 100644 --- a/java-sdk/sdk/src/test/kotlin/org/apache/airflow/sdk/execution/CommTest.kt +++ b/java-sdk/sdk/src/test/kotlin/org/apache/airflow/sdk/execution/CommTest.kt @@ -19,12 +19,27 @@ package org.apache.airflow.sdk.execution +import io.ktor.utils.io.ByteChannel +import io.ktor.utils.io.readByteArray +import io.ktor.utils.io.writeByteArray +import kotlinx.coroutines.runBlocking +import org.apache.airflow.sdk.ApiError +import org.apache.airflow.sdk.Bundle +import org.apache.airflow.sdk.execution.comm.GetVariable import org.apache.airflow.sdk.execution.comm.StartupDetails +import org.apache.airflow.sdk.execution.comm.TaskInstance +import org.apache.airflow.sdk.execution.comm.XComResult import org.junit.jupiter.api.Assertions import org.junit.jupiter.api.DisplayName import org.junit.jupiter.api.Test +import org.junit.jupiter.api.Timeout +import org.msgpack.core.MessagePack +import java.io.ByteArrayOutputStream import java.time.OffsetDateTime import java.time.ZoneOffset +import java.util.concurrent.ConcurrentLinkedQueue +import java.util.concurrent.TimeUnit +import org.apache.airflow.sdk.Client as PublicClient fun byteArrayFromHexString(hexString: String): ByteArray = hexString @@ -83,4 +98,96 @@ class CommsTest { Assertions.assertEquals(expected, actual) } + + private fun responseFrame(id: Int): ByteArray { + val out = ByteArrayOutputStream() + MessagePack.newDefaultPacker(out).use { packer -> + packer.packArrayHeader(3) + packer.packInt(id) + packer.packMapHeader(3) + packer.packString("type") + packer.packString("XComResult") + packer.packString("key") + packer.packString("return_value") + packer.packString("value") + packer.packInt(1) + packer.packNil() + } + return out.toByteArray() + } + + @Test + @DisplayName("Should reject a response whose id does not match the request") + fun rejectsResponseWhoseIdDoesNotMatchRequest() { + val toClient = ByteChannel(autoFlush = true) + val fromClient = ByteChannel(autoFlush = true) + val comm = CoordinatorComm(Bundle(emptyList()), toClient, fromClient) + + val error = + Assertions.assertThrows(ApiError::class.java) { + runBlocking { + // The first request is sent with id 0. The 99 doesn't match 0. + val payload = responseFrame(99) + toClient.writeByteArray(Frame.lengthPrefix(payload.size)) + toClient.writeByteArray(payload) + comm.communicate(GetVariable().also { it.key = "k" }) + } + } + Assertions.assertTrue( + error.message!!.contains("does not match"), + "expected an id-mismatch error, got: ${error.message}", + ) + } + + @Test + @DisplayName("Should stay correlated when the client is called from many threads") + @Timeout(value = 30, unit = TimeUnit.SECONDS) + fun publicClientSurvivesConcurrentThreadCalls() { + val toClient = ByteChannel(autoFlush = true) + val fromClient = ByteChannel(autoFlush = true) + val comm = CoordinatorComm(Bundle(emptyList()), toClient, fromClient) + val details = + StartupDetails().also { + it.ti = + TaskInstance().also { ti -> + ti.dagId = "d" + ti.runId = "r" + } + } + val client = PublicClient(details, CoordinatorClient(comm)) + val n = 50 + + val server = + Thread { + runBlocking { + repeat(n) { + val prefix = fromClient.readByteArray(4) + val payload = fromClient.readByteArray(Frame.parseLengthPrefix(prefix)) + val response = responseFrame(CoordinatorComm.decode(payload).id) + toClient.writeByteArray(Frame.lengthPrefix(response.size)) + toClient.writeByteArray(response) + } + } + } + server.start() + + val errors = ConcurrentLinkedQueue() + val results = ConcurrentLinkedQueue() + val workers = + (1..n).map { + Thread { + try { + results.add(client.getXCom(taskId = "upstream")) + } catch (e: Throwable) { + errors.add(e) + } + } + } + workers.forEach { it.start() } + workers.forEach { it.join() } + server.join() + + Assertions.assertTrue(errors.isEmpty(), "concurrent public-client calls failed: $errors") + Assertions.assertEquals(n, results.size) + } }