Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,31 @@ def test_transform_xcom_is_numeric_timestamp(self):
f"Expected 'transform' XCom to be a positive integer (millisecond timestamp), got {value!r}"
)

def test_concurrent_client_calls_succeed(self):
    """The Java ``concurrent`` task, which calls the client from many threads, ends in ``success``."""
    dag_id = "java_annotation_example"

    # Kick off a fresh run stamped with the current UTC time.
    triggered = self.airflow_client.trigger_dag(
        dag_id,
        json={"logical_date": datetime.now(timezone.utc).isoformat()},
    )
    run_id = triggered["dag_run_id"]

    # Block until the run finishes (or the Java task timeout elapses).
    dag_state = self.airflow_client.wait_for_dag_run(
        dag_id=dag_id, run_id=run_id, timeout=_JAVA_TASK_TIMEOUT
    )

    # Index the run's task instances by task id so the one under test can be picked out.
    listing = self.airflow_client.get_task_instances(dag_id=dag_id, run_id=run_id)
    ti_map = {}
    for ti in listing.get("task_instances", []):
        ti_map[ti["task_id"]] = ti
    concurrent_ti = ti_map.get("concurrent", {})

    # On failure, report the DAG state and every task's state alongside the task under test.
    assert concurrent_ti.get("state") == "success", (
        f"Java 'concurrent' task did not succeed.\n"
        f" task state : {concurrent_ti.get('state')!r}\n"
        f" dag state : {dag_state!r}\n"
        f" all tasks : { {k: v.get('state') for k, v in ti_map.items()} }"
    )

def test_load_retried_then_succeeded(self):
"""``load`` fails once (UP_FOR_RETRY) then succeeds on the second attempt.

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,10 @@

import static java.lang.System.Logger.Level.INFO;

import java.util.ArrayList;
import java.util.Date;
import java.util.concurrent.Callable;
import java.util.concurrent.Executors;
import org.apache.airflow.sdk.*;

@SuppressWarnings("DuplicatedCode")
Expand Down Expand Up @@ -72,4 +75,25 @@ public void load(Context context, @Builder.XCom(task = "transform") long transfo
}
log.log(INFO, "Recovered on retry, try number {0}", context.ti.tryNumber);
}

// Exercise the single supervisor channel from several threads at once: every
// lookup must still come back with the connection that was asked for.
@Builder.Task(id = "concurrent")
public void concurrentClientCalls(Client client) throws Exception {
  final var workers = Executors.newFixedThreadPool(8);
  try {
    // One stateless lookup, queued 32 times so the calls outnumber the 8 worker threads.
    final Callable<String> lookup = () -> client.getConnection("test_http").host;
    final var lookups = new ArrayList<Callable<String>>();
    while (lookups.size() < 32) {
      lookups.add(lookup);
    }

    // invokeAll returns once every lookup has completed; get() rethrows any task failure.
    for (final var pending : workers.invokeAll(lookups)) {
      final String actualHost = pending.get();
      if ("example.com".equals(actualHost)) {
        continue;
      }
      throw new RuntimeException("concurrent getConnection returned wrong host: " + actualHost);
    }
    log.log(INFO, "All concurrent client calls returned the correct connection");
  } finally {
    workers.shutdown();
  }
}
}
5 changes: 5 additions & 0 deletions java-sdk/example/src/resources/dags/java_examples.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,10 @@ def transform(): ...
def load(): ...


# Stub task routed to the "java" queue; the Python body is deliberately empty.
# NOTE(review): presumably the Java task registered with id "concurrent" supplies
# the implementation -- confirm against the Java example bundle.
@task.stub(queue="java")
def concurrent(): ...


@task()
def python_task_2(transformed):
print("python_task_2")
Expand All @@ -61,6 +65,7 @@ def java_annotation_example():
python_task_1() >> extract() >> transformed
python_task_2(transformed)
transformed >> load()
concurrent()


java_interface_example()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@ import io.ktor.utils.io.ByteReadChannel
import io.ktor.utils.io.ByteWriteChannel
import io.ktor.utils.io.readByteArray
import io.ktor.utils.io.writeByteArray
import kotlinx.coroutines.sync.Mutex
import kotlinx.coroutines.sync.withLock
import org.apache.airflow.sdk.ApiError
import org.apache.airflow.sdk.Bundle
import org.apache.airflow.sdk.execution.comm.ErrorResponse
Expand Down Expand Up @@ -56,6 +58,7 @@ class CoordinatorComm(

private val nextId = AtomicInt(0)
private var shutDownRequested = false
private val commMutex = Mutex()

suspend fun startProcessing() {
while (!shutDownRequested) {
Expand Down Expand Up @@ -109,19 +112,22 @@ class CoordinatorComm(
}

@Throws(ApiError::class)
suspend fun communicateImpl(body: Any): Any {
var frame: IncomingFrame? = null
suspend fun communicateImpl(body: Any): Any =
commMutex.withLock {
val requestId = nextId.fetchAndAdd(1)

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this should be outside of the lock since nextId handles concurrency itself?

var frame: IncomingFrame? = null

suspend fun handle(f: IncomingFrame) {
frame = f
}
sendMessage(nextId.fetchAndAdd(1), body)
processOnce(::handle)
if (frame == null) {
throw ApiError("No response received")
suspend fun handle(f: IncomingFrame) {
frame = f
}
sendMessage(requestId, body)
processOnce(::handle)
val received = frame ?: throw ApiError("No response received")
if (received.id != requestId) {
throw ApiError("response id ${received.id} does not match request id $requestId")
}
received.body ?: Unit
Comment on lines +123 to +129

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wonder if there’s a way to do this outside of a lock, since we have the request ID to identify which response is for which. IIRC the Python implementation does this.

}
return frame.body ?: Unit
}

@Throws(ApiError::class)
suspend inline fun <reified T> communicate(request: Any): T {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,27 @@

package org.apache.airflow.sdk.execution

import io.ktor.utils.io.ByteChannel
import io.ktor.utils.io.readByteArray
import io.ktor.utils.io.writeByteArray
import kotlinx.coroutines.runBlocking
import org.apache.airflow.sdk.ApiError
import org.apache.airflow.sdk.Bundle
import org.apache.airflow.sdk.execution.comm.GetVariable
import org.apache.airflow.sdk.execution.comm.StartupDetails
import org.apache.airflow.sdk.execution.comm.TaskInstance
import org.apache.airflow.sdk.execution.comm.XComResult
import org.junit.jupiter.api.Assertions
import org.junit.jupiter.api.DisplayName
import org.junit.jupiter.api.Test
import org.junit.jupiter.api.Timeout
import org.msgpack.core.MessagePack
import java.io.ByteArrayOutputStream
import java.time.OffsetDateTime
import java.time.ZoneOffset
import java.util.concurrent.ConcurrentLinkedQueue
import java.util.concurrent.TimeUnit
import org.apache.airflow.sdk.Client as PublicClient

fun byteArrayFromHexString(hexString: String): ByteArray =
hexString
Expand Down Expand Up @@ -83,4 +98,96 @@ class CommsTest {

Assertions.assertEquals(expected, actual)
}

/**
 * Encodes a canned response frame body for [id] as msgpack.
 *
 * Layout: a 3-element array holding the frame id, a 3-entry map
 * (`type`, `key`, `value`) describing an `XComResult`, and a trailing nil.
 */
private fun responseFrame(id: Int): ByteArray =
    ByteArrayOutputStream()
        .also { sink ->
            // Closing the packer via use() flushes everything into the sink.
            MessagePack.newDefaultPacker(sink).use { packer ->
                packer.packArrayHeader(3)
                packer.packInt(id)

                packer.packMapHeader(3)
                for ((field, text) in listOf("type" to "XComResult", "key" to "return_value")) {
                    packer.packString(field)
                    packer.packString(text)
                }
                packer.packString("value")
                packer.packInt(1)

                packer.packNil()
            }
        }.toByteArray()

/**
 * A response frame whose id differs from the id of the request that was sent
 * must surface as an [ApiError] instead of being returned to the caller.
 */
@Test
@DisplayName("Should reject a response whose id does not match the request")
fun rejectsResponseWhoseIdDoesNotMatchRequest() {
    val toClient = ByteChannel(autoFlush = true)
    val fromClient = ByteChannel(autoFlush = true)
    val comm = CoordinatorComm(Bundle(emptyList()), toClient, fromClient)

    val error =
        Assertions.assertThrows(ApiError::class.java) {
            runBlocking {
                // Queue the reply before issuing the request. The first request is sent
                // with id 0, so a reply carrying id 99 cannot belong to it.
                val payload = responseFrame(99)
                toClient.writeByteArray(Frame.lengthPrefix(payload.size))
                toClient.writeByteArray(payload)
                comm.communicate<XComResult>(GetVariable().also { it.key = "k" })
            }
        }
    // orEmpty() rather than !!: a null message should fail this assertion with the
    // diagnostic below, not abort the test with a NullPointerException.
    Assertions.assertTrue(
        error.message.orEmpty().contains("does not match"),
        "expected an id-mismatch error, got: ${error.message}",
    )
}

/**
 * Calls the public client from `n` threads at once against an in-process echo
 * "server" and checks that every call completes without error.
 *
 * The server thread reads `n` length-prefixed request frames from `fromClient`
 * and answers each with a [responseFrame] carrying the same id on `toClient`.
 */
@Test
@DisplayName("Should stay correlated when the client is called from many threads")
@Timeout(value = 30, unit = TimeUnit.SECONDS)
fun publicClientSurvivesConcurrentThreadCalls() {
    val toClient = ByteChannel(autoFlush = true)
    val fromClient = ByteChannel(autoFlush = true)
    val comm = CoordinatorComm(Bundle(emptyList()), toClient, fromClient)
    val details =
        StartupDetails().also {
            it.ti =
                TaskInstance().also { ti ->
                    ti.dagId = "d"
                    ti.runId = "r"
                }
        }
    val client = PublicClient(details, CoordinatorClient(comm))
    val n = 50

    // Daemon threads: if one side fails, its peer may stay blocked on the channel
    // forever; it must not keep the JVM alive once @Timeout has aborted the test.
    val server =
        Thread {
            runBlocking {
                repeat(n) {
                    val prefix = fromClient.readByteArray(4)
                    val payload = fromClient.readByteArray(Frame.parseLengthPrefix(prefix))
                    // Echo the request id back so the client can correlate the reply.
                    val response = responseFrame(CoordinatorComm.decode(payload).id)
                    toClient.writeByteArray(Frame.lengthPrefix(response.size))
                    toClient.writeByteArray(response)
                }
            }
        }.apply { isDaemon = true }
    server.start()

    val errors = ConcurrentLinkedQueue<Throwable>()
    // Count successes instead of queueing the returned values: ConcurrentLinkedQueue
    // rejects null elements, so a null result would be misreported as a failed call.
    val successes = java.util.concurrent.atomic.AtomicInteger()
    val workers =
        (1..n).map {
            Thread {
                try {
                    client.getXCom(taskId = "upstream")
                    successes.incrementAndGet()
                } catch (e: Throwable) {
                    // Collected rather than rethrown so the failure is reported on the test thread.
                    errors.add(e)
                }
            }.apply { isDaemon = true }
        }
    workers.forEach { it.start() }
    workers.forEach { it.join() }
    server.join()

    Assertions.assertTrue(errors.isEmpty(), "concurrent public-client calls failed: $errors")
    Assertions.assertEquals(n, successes.get())
}
}
Loading