diff --git a/kafka-streams-framework/build.gradle.kts b/kafka-streams-framework/build.gradle.kts index dcdbf36..c65e5b8 100644 --- a/kafka-streams-framework/build.gradle.kts +++ b/kafka-streams-framework/build.gradle.kts @@ -33,6 +33,7 @@ dependencies { testImplementation(commonLibs.junit.jupiter) testImplementation(localLibs.junit.pioneer) testImplementation(commonLibs.mockito.core) + testImplementation(commonLibs.mockito.junit) testImplementation(localLibs.hamcrest.core) testRuntimeOnly(commonLibs.log4j.slf4j2.impl) } diff --git a/kafka-streams-framework/gradle.lockfile b/kafka-streams-framework/gradle.lockfile index e1efd02..e5d2e7e 100644 --- a/kafka-streams-framework/gradle.lockfile +++ b/kafka-streams-framework/gradle.lockfile @@ -119,6 +119,7 @@ org.junit:junit-bom:5.10.0=testCompileClasspath org.junit:junit-bom:5.11.2=testRuntimeClasspath org.latencyutils:LatencyUtils:2.0.3=runtimeClasspath,testRuntimeClasspath org.mockito:mockito-core:5.8.0=testCompileClasspath,testRuntimeClasspath +org.mockito:mockito-junit-jupiter:5.8.0=testCompileClasspath,testRuntimeClasspath org.objenesis:objenesis:3.3=testRuntimeClasspath org.opentest4j:opentest4j:1.3.0=testCompileClasspath,testRuntimeClasspath org.projectlombok:lombok:1.18.30=annotationProcessor,compileClasspath,testCompileClasspath diff --git a/kafka-streams-framework/src/main/java/org/hypertrace/core/kafkastreams/framework/KafkaStreamsApp.java b/kafka-streams-framework/src/main/java/org/hypertrace/core/kafkastreams/framework/KafkaStreamsApp.java index 2e95b3e..e30ba12 100644 --- a/kafka-streams-framework/src/main/java/org/hypertrace/core/kafkastreams/framework/KafkaStreamsApp.java +++ b/kafka-streams-framework/src/main/java/org/hypertrace/core/kafkastreams/framework/KafkaStreamsApp.java @@ -16,6 +16,7 @@ import static org.apache.kafka.streams.StreamsConfig.DEFAULT_TIMESTAMP_EXTRACTOR_CLASS_CONFIG; import static org.apache.kafka.streams.StreamsConfig.DEFAULT_VALUE_SERDE_CLASS_CONFIG; import static 
org.apache.kafka.streams.StreamsConfig.METRICS_RECORDING_LEVEL_CONFIG;
+import static org.apache.kafka.streams.StreamsConfig.NUM_STREAM_THREADS_CONFIG;
 import static org.apache.kafka.streams.StreamsConfig.ROCKSDB_CONFIG_SETTER_CLASS_CONFIG;
 import static org.apache.kafka.streams.StreamsConfig.TOPOLOGY_OPTIMIZATION;
 import static org.apache.kafka.streams.StreamsConfig.consumerPrefix;
@@ -34,6 +35,8 @@
 import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
+import java.util.Optional;
+import java.util.OptionalInt;
 import java.util.Properties;
 import java.util.concurrent.TimeUnit;
 import java.util.stream.Collectors;
@@ -49,6 +52,7 @@
 import org.hypertrace.core.kafkastreams.framework.listeners.LoggingStateListener;
 import org.hypertrace.core.kafkastreams.framework.listeners.LoggingStateRestoreListener;
 import org.hypertrace.core.kafkastreams.framework.rocksdb.BoundedMemoryConfigSetter;
+import org.hypertrace.core.kafkastreams.framework.threading.StreamThreadsCountResolver;
 import org.hypertrace.core.kafkastreams.framework.timestampextractors.UseWallclockTimeOnInvalidTimestamp;
 import org.hypertrace.core.kafkastreams.framework.topics.creator.KafkaTopicCreator;
 import org.hypertrace.core.kafkastreams.framework.util.ExceptionUtils;
@@ -65,6 +69,16 @@ public abstract class KafkaStreamsApp extends PlatformService {
   public static final String CLEANUP_LOCAL_STATE = "cleanup.local.state";
   public static final String PRE_CREATE_TOPICS = "precreate.topics";
   public static final String KAFKA_STREAMS_CONFIG_KEY = "kafka.streams.config";
+
+  /**
+   * Framework-level boolean opt-in (set in the streams config map) that gates dynamic {@code
+   * num.stream.threads} resolution. Apps must also override {@link
+   * #getStreamThreadsCountResolver()}; this flag exists separately so deployments can roll out and
+   * roll back per-cluster without code changes. The flag is consumed by the framework before the
+   * config reaches Kafka Streams and is not a Kafka config key.
+   */
+  public static final String DYNAMIC_NUM_STREAM_THREADS_CONFIG = "dynamic.num.stream.threads";
+
   private static final String SHUTDOWN_DURATION = "shutdown.duration";
   private static final Logger logger = LoggerFactory.getLogger(KafkaStreamsApp.class);
 
@@ -106,6 +120,17 @@ protected void doInit() {
 
       streamsBuilder = buildTopology(streamsConfig, streamsBuilder, sourceStreams);
       this.topology = streamsBuilder.build();
+      // Strip the dynamic-resolution flag before it reaches Kafka Streams (it's a framework-level
+      // opt-in, not a Kafka config). When enabled and a resolver is wired up, replace the
+      // configured num.stream.threads with the dynamically-computed value; otherwise the
+      // configured value flows through unchanged.
+      final boolean dynamicEnabled = isDynamicNumStreamThreadsEnabled(streamsConfig);
+      streamsConfig.remove(DYNAMIC_NUM_STREAM_THREADS_CONFIG);
+      if (dynamicEnabled) {
+        resolveDynamicStreamThreads(streamsConfig)
+            .ifPresent(threads -> streamsConfig.put(NUM_STREAM_THREADS_CONFIG, threads));
+      }
+
       getLogger().info("Finalized kafka streams configuration: {}", streamsConfig);
 
       // pre-create input/output topics required for kstream application
@@ -257,6 +282,91 @@ public abstract StreamsBuilder buildTopology(
       StreamsBuilder streamsBuilder,
       Map<String, KStream<?, ?>> sourceStreams);
 
+  /**
+   * Override in subclasses that want auto-sized {@code num.stream.threads}. Return a resolver
+   * configured with this app's replica-count source. The framework only invokes this when {@link
+   * #DYNAMIC_NUM_STREAM_THREADS_CONFIG} is {@code true} in the streams config; returning an empty
+   * optional (the default) disables dynamic resolution.
+   *
+   * <p>Apps that don't override this are unaffected — the default returns {@code Optional.empty()}
+   * and the framework leaves {@code num.stream.threads} exactly as configured. Apps that do
+   * override should supply replica count from wherever the deployment exposes it (e.g. the {@code
+   * REPLICA_COUNT} environment variable injected by the k8s template, or a HOCON config key).
+   *
+   * <p>Example:
+   *
+   * <pre>{@code
+   * @Override
+   * protected Optional<StreamThreadsCountResolver> getStreamThreadsCountResolver() {
+   *   final Optional<Integer> replicaCount =
+   *       ConfigUtils.optionalInteger(getAppConfig(), "replica.count");
+   *   return Optional.of(
+   *       new StreamThreadsCountResolver(
+   *           StreamThreadsCountResolver.optionalReplicaCount(replicaCount)));
+   * }
+   * }</pre>
+   *
+   * <p>The configured numeric {@code num.stream.threads} flows through unchanged when:
+   *
+   * <ul>
+   *   <li>{@link #DYNAMIC_NUM_STREAM_THREADS_CONFIG} is absent or not {@code true};
+   *   <li>this method returns an empty optional or throws;
+   *   <li>the resolver's {@code resolve} call returns {@link OptionalInt#empty()} or throws.
+   * </ul>
+   *
+   * @return the resolver to use for dynamic sizing, or an empty optional to keep the configured
+   *     {@code num.stream.threads}
+   */
+  protected Optional<StreamThreadsCountResolver> getStreamThreadsCountResolver() {
+    return Optional.empty();
+  }
+
+  // Caller must check the dynamic-enabled flag before invoking. Returns OptionalInt.empty() when
+  // no resolver is wired, the resolver throws, or resolution cannot produce a value — caller
+  // keeps the configured num.stream.threads as the fallback.
+  private OptionalInt resolveDynamicStreamThreads(Map<String, Object> streamsProperties) {
+    final Optional<StreamThreadsCountResolver> resolver;
+    try {
+      resolver = getStreamThreadsCountResolver();
+    } catch (Exception exception) {
+      getLogger()
+          .warn(
+              "getStreamThreadsCountResolver() threw; keeping configured num.stream.threads",
+              exception);
+      return OptionalInt.empty();
+    }
+    if (resolver.isEmpty()) {
+      getLogger()
+          .warn(
+              "{} is true but no StreamThreadsCountResolver is provided; keeping configured num.stream.threads",
+              DYNAMIC_NUM_STREAM_THREADS_CONFIG);
+      return OptionalInt.empty();
+    }
+    final OptionalInt resolved;
+    try {
+      resolved = resolver.get().resolve(this.topology, streamsProperties);
+    } catch (Exception exception) {
+      // resolve() is overridable and calls app-supplied code (the replica-count supplier), so
+      // guard it the same way as the factory method above: a failure must not abort startup.
+      getLogger()
+          .warn(
+              "StreamThreadsCountResolver.resolve() threw; keeping configured num.stream.threads",
+              exception);
+      return OptionalInt.empty();
+    }
+    if (resolved.isEmpty()) {
+      getLogger()
+          .warn("Resolver could not compute dynamic num.stream.threads; keeping configured value");
+    }
+    return resolved;
+  }
+
+  private static boolean isDynamicNumStreamThreadsEnabled(
+      final Map<String, Object> streamsProperties) {
+    final Object value = streamsProperties.get(DYNAMIC_NUM_STREAM_THREADS_CONFIG);
+    return value != null && Boolean.parseBoolean(String.valueOf(value));
+  }
+
   public Map<String, Object> getStreamsConfig(Config jobConfig) {
     return new HashMap<>(ConfigUtils.getFlatMapConfig(jobConfig, getStreamsConfigKey()));
   }
diff --git a/kafka-streams-framework/src/main/java/org/hypertrace/core/kafkastreams/framework/threading/DynamicStreamThreadsCountCalculator.java b/kafka-streams-framework/src/main/java/org/hypertrace/core/kafkastreams/framework/threading/DynamicStreamThreadsCountCalculator.java
new file mode 100644
index 0000000..50d5471
--- /dev/null
+++ 
b/kafka-streams-framework/src/main/java/org/hypertrace/core/kafkastreams/framework/threading/DynamicStreamThreadsCountCalculator.java
@@ -0,0 +1,191 @@
+package org.hypertrace.core.kafkastreams.framework.threading;
+
+import static java.util.stream.Collectors.toUnmodifiableSet;
+
+import java.time.Duration;
+import java.util.HashMap;
+import java.util.Map;
+import java.util.Map.Entry;
+import java.util.OptionalInt;
+import java.util.Set;
+import java.util.concurrent.ExecutionException;
+import java.util.concurrent.TimeUnit;
+import java.util.concurrent.TimeoutException;
+import org.apache.kafka.clients.admin.AdminClient;
+import org.apache.kafka.clients.admin.DescribeTopicsResult;
+import org.apache.kafka.clients.admin.TopicDescription;
+import org.apache.kafka.common.KafkaFuture;
+import org.apache.kafka.common.errors.UnknownTopicOrPartitionException;
+import org.apache.kafka.streams.Topology;
+import org.apache.kafka.streams.TopologyDescription;
+import org.apache.kafka.streams.TopologyDescription.Source;
+import org.apache.kafka.streams.TopologyDescription.Subtopology;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * Computes a per-instance {@code num.stream.threads} value from a topology and the partition count
+ * of every source topic.
+ *
+ * <p>For each sub-topology the maximum partition count across its source topics is the number of
+ * stream tasks. Summing across sub-topologies and dividing by the replica count yields the threads
+ * each instance should run to keep all tasks active without idle threads.
+ *
+ * <p>Returns {@link OptionalInt#empty()} when the topology contains a regex/pattern subscription
+ * ({@link Source#topicPattern()}) — those sub-topologies cannot be enumerated against the broker
+ * up-front, so dynamic sizing would silently under-count tasks. The caller falls back to its
+ * configured default in that case.
+ *
+ * <p>NOTE(review): source topic names are looked up on the broker exactly as {@link
+ * Source#topicSet()} reports them. Presumably internal repartition topics appear there without the
+ * {@code application.id} prefix and would therefore be counted as absent — confirm before relying
+ * on this for topologies that repartition.
+ */
+public class DynamicStreamThreadsCountCalculator {
+
+  private static final long DESCRIBE_TOPICS_TIMEOUT_MILLIS = Duration.ofSeconds(5).toMillis();
+  private static final Logger logger =
+      LoggerFactory.getLogger(DynamicStreamThreadsCountCalculator.class);
+
+  private static Set<String> sourceTopicsOf(final Subtopology subtopology) {
+    return subtopology.nodes().stream()
+        .filter(node -> node instanceof Source)
+        .map(node -> (Source) node)
+        .flatMap(source -> source.topicSet().stream())
+        .collect(toUnmodifiableSet());
+  }
+
+  private static boolean hasPatternSource(final Subtopology subtopology) {
+    return subtopology.nodes().stream()
+        .filter(node -> node instanceof Source)
+        .map(node -> (Source) node)
+        .anyMatch(source -> source.topicPattern() != null);
+  }
+
+  /**
+   * Computes the thread count one instance should run.
+   *
+   * @param topology topology whose sub-topologies and source topics are inspected
+   * @param adminClient client used to look up partition counts; not closed by this method
+   * @param replicas number of application instances sharing the tasks; must be positive
+   * @return {@code ceil(totalTasks / replicas)}, or empty when the topology has a pattern source
+   *     or no source topic has any resolvable partition
+   * @throws IllegalArgumentException if {@code replicas} is not positive
+   * @throws RuntimeException if describing a topic times out, is interrupted, or fails for a
+   *     reason other than the topic being absent
+   */
+  public OptionalInt compute(
+      final Topology topology, final AdminClient adminClient, final int replicas) {
+    if (replicas <= 0) {
+      throw new IllegalArgumentException("replicas must be positive, got " + replicas);
+    }
+
+    final TopologyDescription description = topology.describe();
+
+    // Bail out if any sub-topology subscribes via regex — topicSet() is empty for those, so
+    // dynamic sizing would silently under-count tasks. The caller substitutes its fallback.
+    final boolean anyPatternSource =
+        description.subtopologies().stream()
+            .anyMatch(DynamicStreamThreadsCountCalculator::hasPatternSource);
+    if (anyPatternSource) {
+      logger.warn(
+          "Topology contains a regex/pattern source; dynamic num.stream.threads is not supported. "
+              + "Caller will fall back to its configured default.");
+      return OptionalInt.empty();
+    }
+
+    final Set<String> sourceTopics =
+        description.subtopologies().stream()
+            .flatMap(subtopology -> sourceTopicsOf(subtopology).stream())
+            .collect(toUnmodifiableSet());
+
+    final Map<String, Integer> partitionsByTopic = describePartitions(adminClient, sourceTopics);
+
+    int totalTasks = 0;
+    int subtopologyCount = 0;
+    for (final Subtopology subtopology : description.subtopologies()) {
+      subtopologyCount++;
+      final Set<String> subtopologyTopics = sourceTopicsOf(subtopology);
+
+      final int tasksForSubtopology =
+          subtopologyTopics.stream()
+              .mapToInt(topic -> partitionsByTopic.getOrDefault(topic, 0))
+              .max()
+              .orElse(0);
+
+      if (tasksForSubtopology == 0) {
+        logger.warn(
+            "Sub-topology has no resolvable partitions; topics={}. Pod restart will be needed once topics exist.",
+            subtopologyTopics);
+      }
+      totalTasks += tasksForSubtopology;
+    }
+
+    if (totalTasks == 0) {
+      logger.warn(
+          "No resolvable partitions across {} sub-topologies; skipping dynamic num.stream.threads.",
+          subtopologyCount);
+      return OptionalInt.empty();
+    }
+
+    final int threads = (int) Math.ceil((double) totalTasks / replicas);
+    logger.info(
+        "Dynamic num.stream.threads: totalTasks={} across {} sub-topologies, replicas={}, computed={}",
+        totalTasks,
+        subtopologyCount,
+        replicas,
+        threads);
+    return OptionalInt.of(threads);
+  }
+
+  // Single-loop implementation: AdminClient.describeTopics() already fires all RPCs concurrently
+  // before returning futures, so iteration here only consumes a shared deadline (now+timeout) —
+  // total wall-clock is capped at DESCRIBE_TOPICS_TIMEOUT_MILLIS regardless of topic count.
+  // The deadline is tracked with System.nanoTime() (monotonic), so a wall-clock adjustment during
+  // startup cannot stretch or shorten the wait.
+  private Map<String, Integer> describePartitions(
+      final AdminClient adminClient, final Set<String> topics) {
+    if (topics.isEmpty()) {
+      return Map.of();
+    }
+    final DescribeTopicsResult result = adminClient.describeTopics(topics);
+    final Map<String, KafkaFuture<TopicDescription>> futures = result.topicNameValues();
+    final long deadlineNanos =
+        System.nanoTime() + TimeUnit.MILLISECONDS.toNanos(DESCRIBE_TOPICS_TIMEOUT_MILLIS);
+    final Map<String, Integer> partitions = new HashMap<>();
+
+    for (final Entry<String, KafkaFuture<TopicDescription>> entry : futures.entrySet()) {
+      // Clamp to zero rather than failing before the call: a zero-timeout get() still yields the
+      // result of a future that has already completed and times out only if it has not.
+      final long remainingNanos = Math.max(0L, deadlineNanos - System.nanoTime());
+      try {
+        partitions.put(
+            entry.getKey(),
+            entry.getValue().get(remainingNanos, TimeUnit.NANOSECONDS).partitions().size());
+      } catch (final TimeoutException timeoutException) {
+        throw new RuntimeException(
+            "Timed out describing topic "
+                + entry.getKey()
+                + " after "
+                + DESCRIBE_TOPICS_TIMEOUT_MILLIS
+                + "ms",
+            timeoutException);
+      } catch (final InterruptedException interruptedException) {
+        Thread.currentThread().interrupt();
+        throw new RuntimeException(
+            "Interrupted while describing topic " + entry.getKey(), interruptedException);
+      } catch (final ExecutionException executionException) {
+        if (executionException.getCause() instanceof UnknownTopicOrPartitionException) {
+          logger.warn(
+              "Topic absent on broker: {}. Treating as 0 partitions; restart needed once created.",
+              entry.getKey());
+          partitions.put(entry.getKey(), 0);
+        } else {
+          throw new RuntimeException(
+              "Failed to describe topic " + entry.getKey(), executionException);
+        }
+      }
+    }
+    return Map.copyOf(partitions);
+  }
+}
diff --git a/kafka-streams-framework/src/main/java/org/hypertrace/core/kafkastreams/framework/threading/StreamThreadsCountResolver.java b/kafka-streams-framework/src/main/java/org/hypertrace/core/kafkastreams/framework/threading/StreamThreadsCountResolver.java
new file mode 100644
index 0000000..0daa8cf
--- /dev/null
+++ b/kafka-streams-framework/src/main/java/org/hypertrace/core/kafkastreams/framework/threading/StreamThreadsCountResolver.java
@@ -0,0 +1,88 @@
+package org.hypertrace.core.kafkastreams.framework.threading;
+
+import java.util.Map;
+import java.util.Optional;
+import java.util.OptionalInt;
+import java.util.Properties;
+import java.util.function.Function;
+import java.util.function.IntSupplier;
+import org.apache.kafka.clients.admin.AdminClient;
+import org.apache.kafka.streams.Topology;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * Resolves a concrete {@code num.stream.threads} value for an app that has opted into dynamic
+ * thread sizing. Bridges the topology + AdminClient to the {@link
+ * DynamicStreamThreadsCountCalculator} and returns {@link OptionalInt#empty()} when resolution
+ * cannot produce a meaningful value — the caller keeps the configured numeric {@code
+ * num.stream.threads}.
+ *
+ *

<p>The replica count is supplied externally — applications typically wire it from a {@code
+ * REPLICA_COUNT} environment variable injected by the deployment template. A non-positive value
+ * (zero, negative, or unset) yields {@link OptionalInt#empty()}.
+ */
+public class StreamThreadsCountResolver {
+
+  private static final Logger logger = LoggerFactory.getLogger(StreamThreadsCountResolver.class);
+
+  private final DynamicStreamThreadsCountCalculator calculator;
+  private final IntSupplier replicaCountSupplier;
+  private final Function<Properties, AdminClient> adminClientFactory;
+
+  public StreamThreadsCountResolver(final IntSupplier replicaCountSupplier) {
+    this(new DynamicStreamThreadsCountCalculator(), replicaCountSupplier, AdminClient::create);
+  }
+
+  public StreamThreadsCountResolver(
+      final DynamicStreamThreadsCountCalculator calculator,
+      final IntSupplier replicaCountSupplier,
+      final Function<Properties, AdminClient> adminClientFactory) {
+    this.calculator = calculator;
+    this.replicaCountSupplier = replicaCountSupplier;
+    this.adminClientFactory = adminClientFactory;
+  }
+
+  /**
+   * Resolve a thread count for the given topology. Returns {@link OptionalInt#empty()} when
+   * prerequisites are missing (non-positive replica count), the topology is unsupported (regex
+   * sources), or the AdminClient/calculator call fails — the caller keeps the configured numeric
+   * {@code num.stream.threads} so the application can still start with a sane value.
+   */
+  public OptionalInt resolve(final Topology topology, final Map<String, Object> streamsProperties) {
+    final int replicas = replicaCountSupplier.getAsInt();
+    if (replicas <= 0) {
+      logger.warn("replica.count is non-positive ({}); skipping dynamic resolution", replicas);
+      return OptionalInt.empty();
+    }
+    try (final AdminClient adminClient =
+        adminClientFactory.apply(toProperties(streamsProperties))) {
+      return calculator.compute(topology, adminClient, replicas);
+    } catch (final RuntimeException runtimeException) {
+      logger.error(
+          "Failed to compute dynamic num.stream.threads; skipping dynamic resolution",
+          runtimeException);
+      return OptionalInt.empty();
+    }
+  }
+
+  /**
+   * Convenience adapter — given an {@code Optional<Integer>} replica-count source (the common
+   * config-loaded shape), produce a supplier that yields the value when present and {@code 0} when
+   * absent (treated as a "skip dynamic resolution" signal by {@link #resolve}).
+   */
+  public static IntSupplier optionalReplicaCount(final Optional<Integer> replicaCount) {
+    return () -> replicaCount.orElse(0);
+  }
+
+  private static Properties toProperties(final Map<String, Object> streamsProperties) {
+    final Properties properties = new Properties();
+    streamsProperties.forEach(
+        (key, value) -> {
+          if (value != null) {
+            properties.put(key, value);
+          }
+        });
+    return properties;
+  }
+}
diff --git a/kafka-streams-framework/src/test/java/org/hypertrace/core/kafkastreams/framework/SampleAppTest.java b/kafka-streams-framework/src/test/java/org/hypertrace/core/kafkastreams/framework/SampleAppTest.java
index fc4920e..80e1429 100644
--- a/kafka-streams-framework/src/test/java/org/hypertrace/core/kafkastreams/framework/SampleAppTest.java
+++ b/kafka-streams-framework/src/test/java/org/hypertrace/core/kafkastreams/framework/SampleAppTest.java
@@ -4,21 +4,29 @@
 import static org.apache.kafka.streams.StreamsConfig.DEFAULT_DESERIALIZATION_EXCEPTION_HANDLER_CLASS_CONFIG;
 import static org.apache.kafka.streams.StreamsConfig.DEFAULT_KEY_SERDE_CLASS_CONFIG;
import static org.apache.kafka.streams.StreamsConfig.DEFAULT_VALUE_SERDE_CLASS_CONFIG;
+import static org.apache.kafka.streams.StreamsConfig.NUM_STREAM_THREADS_CONFIG;
 import static org.apache.kafka.streams.StreamsConfig.ROCKSDB_CONFIG_SETTER_CLASS_CONFIG;
 import static org.apache.kafka.streams.StreamsConfig.producerPrefix;
 import static org.hamcrest.CoreMatchers.equalTo;
 import static org.hamcrest.CoreMatchers.is;
 import static org.hamcrest.MatcherAssert.assertThat;
 
+import com.typesafe.config.Config;
 import io.confluent.kafka.streams.serdes.avro.SpecificAvroSerde;
 import java.util.Map;
+import java.util.Optional;
 import java.util.Properties;
+import java.util.regex.Pattern;
 import org.apache.kafka.common.serialization.Serdes;
+import org.apache.kafka.streams.StreamsBuilder;
 import org.apache.kafka.streams.TestInputTopic;
 import org.apache.kafka.streams.TestOutputTopic;
 import org.apache.kafka.streams.TopologyTestDriver;
 import org.apache.kafka.streams.errors.LogAndContinueExceptionHandler;
+import org.apache.kafka.streams.kstream.Consumed;
+import org.apache.kafka.streams.kstream.KStream;
 import org.hypertrace.core.kafkastreams.framework.rocksdb.BoundedMemoryConfigSetter;
+import org.hypertrace.core.kafkastreams.framework.threading.StreamThreadsCountResolver;
 import org.hypertrace.core.serviceframework.config.ConfigClientFactory;
 import org.junit.jupiter.api.AfterEach;
 import org.junit.jupiter.api.BeforeEach;
@@ -84,4 +92,91 @@ public void baseStreamsConfigTest() {
         is(LogAndContinueExceptionHandler.class));
     assertThat(baseStreamsConfig.get(producerPrefix(ACKS_CONFIG)), is("all"));
   }
+
+  // No resolver wired up → framework must keep the configured num.stream.threads (the configured
+  // value is the fallback by definition) and strip the framework-only flag before it reaches
+  // Kafka Streams.
+  @Test
+  public void dynamicWithoutResolverKeepsConfiguredValue() {
+    final Object configuredThreads = sampleApp.streamsConfig.get(NUM_STREAM_THREADS_CONFIG);
+
+    SampleApp dynamicApp =
+        new SampleApp(ConfigClientFactory.getClient()) {
+          @Override
+          public Map<String, Object> getStreamsConfig(Config jobConfig) {
+            Map<String, Object> properties = super.getStreamsConfig(jobConfig);
+            properties.put(KafkaStreamsApp.DYNAMIC_NUM_STREAM_THREADS_CONFIG, true);
+            return properties;
+          }
+        };
+
+    dynamicApp.doInit();
+
+    assertThat(dynamicApp.streamsConfig.get(NUM_STREAM_THREADS_CONFIG), is(configuredThreads));
+    assertThat(
+        dynamicApp.streamsConfig.containsKey(KafkaStreamsApp.DYNAMIC_NUM_STREAM_THREADS_CONFIG),
+        is(false));
+  }
+
+  // Pattern-source topology: calculator returns OptionalInt.empty(). Configured num.stream.threads
+  // flows through unchanged.
+  @Test
+  public void dynamicWithPatternSourceKeepsConfiguredValue() {
+    final Object configuredThreads = sampleApp.streamsConfig.get(NUM_STREAM_THREADS_CONFIG);
+
+    SampleApp dynamicApp =
+        new SampleApp(ConfigClientFactory.getClient()) {
+          @Override
+          public Map<String, Object> getStreamsConfig(Config jobConfig) {
+            Map<String, Object> properties = super.getStreamsConfig(jobConfig);
+            properties.put(KafkaStreamsApp.DYNAMIC_NUM_STREAM_THREADS_CONFIG, true);
+            return properties;
+          }
+
+          @Override
+          public StreamsBuilder buildTopology(
+              Map<String, Object> streamsConfig,
+              StreamsBuilder streamsBuilder,
+              Map<String, KStream<?, ?>> sourceStreams) {
+            streamsBuilder.stream(
+                    Pattern.compile("input-.*"), Consumed.with(Serdes.String(), Serdes.String()))
+                .foreach((key, value) -> {});
+            return streamsBuilder;
+          }
+
+          @Override
+          protected Optional<StreamThreadsCountResolver> getStreamThreadsCountResolver() {
+            return Optional.of(new StreamThreadsCountResolver(() -> 8));
+          }
+        };
+
+    dynamicApp.doInit();
+
+    assertThat(dynamicApp.streamsConfig.get(NUM_STREAM_THREADS_CONFIG), is(configuredThreads));
+  }
+
+  // Resolver throws → configured num.stream.threads flows through unchanged.
+  @Test
+  public void dynamicWithThrowingResolverKeepsConfiguredValue() {
+    final Object configuredThreads = sampleApp.streamsConfig.get(NUM_STREAM_THREADS_CONFIG);
+
+    SampleApp dynamicApp =
+        new SampleApp(ConfigClientFactory.getClient()) {
+          @Override
+          public Map<String, Object> getStreamsConfig(Config jobConfig) {
+            Map<String, Object> properties = super.getStreamsConfig(jobConfig);
+            properties.put(KafkaStreamsApp.DYNAMIC_NUM_STREAM_THREADS_CONFIG, true);
+            return properties;
+          }
+
+          @Override
+          protected Optional<StreamThreadsCountResolver> getStreamThreadsCountResolver() {
+            throw new RuntimeException("simulated wiring failure");
+          }
+        };
+
+    dynamicApp.doInit();
+
+    assertThat(dynamicApp.streamsConfig.get(NUM_STREAM_THREADS_CONFIG), is(configuredThreads));
+  }
 }
diff --git a/kafka-streams-framework/src/test/java/org/hypertrace/core/kafkastreams/framework/threading/DynamicStreamThreadsCountCalculatorTest.java b/kafka-streams-framework/src/test/java/org/hypertrace/core/kafkastreams/framework/threading/DynamicStreamThreadsCountCalculatorTest.java
new file mode 100644
index 0000000..5f82245
--- /dev/null
+++ b/kafka-streams-framework/src/test/java/org/hypertrace/core/kafkastreams/framework/threading/DynamicStreamThreadsCountCalculatorTest.java
@@ -0,0 +1,133 @@
+package org.hypertrace.core.kafkastreams.framework.threading;
+
+import static java.util.stream.Collectors.toUnmodifiableList;
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertThrows;
+import static org.mockito.ArgumentMatchers.anySet;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.when;
+
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.OptionalInt;
+import java.util.Set;
+import java.util.regex.Pattern;
+import java.util.stream.IntStream;
+import org.apache.kafka.clients.admin.AdminClient;
+import org.apache.kafka.clients.admin.DescribeTopicsResult;
+import org.apache.kafka.clients.admin.TopicDescription;
+import 
org.apache.kafka.common.KafkaFuture;
+import org.apache.kafka.common.Node;
+import org.apache.kafka.common.TopicPartitionInfo;
+import org.apache.kafka.common.errors.UnknownTopicOrPartitionException;
+import org.apache.kafka.common.internals.KafkaFutureImpl;
+import org.apache.kafka.common.serialization.Serdes;
+import org.apache.kafka.streams.StreamsBuilder;
+import org.apache.kafka.streams.Topology;
+import org.apache.kafka.streams.kstream.Consumed;
+import org.junit.jupiter.api.Test;
+import org.junit.jupiter.api.extension.ExtendWith;
+import org.mockito.Mock;
+import org.mockito.junit.jupiter.MockitoExtension;
+
+@ExtendWith(MockitoExtension.class)
+class DynamicStreamThreadsCountCalculatorTest {
+
+  private static final int ABSENT = -1;
+
+  private final DynamicStreamThreadsCountCalculator calculator =
+      new DynamicStreamThreadsCountCalculator();
+
+  @Mock private AdminClient adminClient;
+
+  @Test
+  void singleSubtopologyWithThirtyPartitionsAndEightReplicasGivesFourThreads() {
+    final Topology topology = topologyForSubtopologies(Set.of("topic-a"));
+    stubPartitions(Map.of("topic-a", 30));
+
+    assertEquals(OptionalInt.of(4), calculator.compute(topology, adminClient, 8));
+  }
+
+  @Test
+  void twoSubtopologiesAreSummedThenDividedByReplicas() {
+    final Topology topology = topologyForSubtopologies(Set.of("topic-a"), Set.of("topic-b"));
+    stubPartitions(Map.of("topic-a", 10, "topic-b", 20));
+
+    assertEquals(OptionalInt.of(4), calculator.compute(topology, adminClient, 8));
+  }
+
+  @Test
+  void absentTopicCountsAsZeroPartitions() {
+    final Topology topology = topologyForSubtopologies(Set.of("topic-a"), Set.of("topic-missing"));
+    stubPartitions(Map.of("topic-a", 16, "topic-missing", ABSENT));
+
+    assertEquals(OptionalInt.of(2), calculator.compute(topology, adminClient, 8));
+  }
+
+  @Test
+  void zeroOrNegativeReplicasThrows() {
+    final Topology topology = topologyForSubtopologies(Set.of("topic-a"));
+
+    assertThrows(
+        IllegalArgumentException.class, () -> calculator.compute(topology, adminClient, 0));
+    assertThrows(
+        IllegalArgumentException.class, () -> calculator.compute(topology, adminClient, -1));
+  }
+
+  // No resolvable partitions -> caller keeps the configured num.stream.threads as-is.
+  @Test
+  void totalTasksZeroReturnsEmpty() {
+    final Topology topology = topologyForSubtopologies(Set.of("topic-a"));
+    stubPartitions(Map.of("topic-a", ABSENT));
+
+    assertEquals(OptionalInt.empty(), calculator.compute(topology, adminClient, 8));
+  }
+
+  // Regex/pattern subscriptions cannot be enumerated up-front against the broker, so dynamic
+  // sizing would silently under-count tasks. Calculator must signal "not applicable" via empty
+  // so the caller falls back to its configured default.
+  @Test
+  void patternSubscriptionReturnsEmpty() {
+    final StreamsBuilder builder = new StreamsBuilder();
+    builder.stream(Pattern.compile("topic-.*"), Consumed.with(Serdes.String(), Serdes.String()))
+        .foreach((key, value) -> {});
+    final Topology topology = builder.build();
+
+    assertEquals(OptionalInt.empty(), calculator.compute(topology, adminClient, 8));
+  }
+
+  @SafeVarargs
+  private static Topology topologyForSubtopologies(final Set<String>... topicsBySubtopology) {
+    final StreamsBuilder builder = new StreamsBuilder();
+    for (final Set<String> topics : topicsBySubtopology) {
+      for (final String topic : topics) {
+        builder.stream(topic, Consumed.with(Serdes.String(), Serdes.String()))
+            .foreach((key, value) -> {});
+      }
+    }
+    return builder.build();
+  }
+
+  private void stubPartitions(final Map<String, Integer> partitionsByTopic) {
+    final DescribeTopicsResult result = mock(DescribeTopicsResult.class);
+    final Map<String, KafkaFuture<TopicDescription>> futures = new HashMap<>();
+    partitionsByTopic.forEach(
+        (topic, partitionCount) -> futures.put(topic, futureFor(topic, partitionCount)));
+    when(result.topicNameValues()).thenReturn(futures);
+    when(adminClient.describeTopics(anySet())).thenReturn(result);
+  }
+
+  private static KafkaFuture<TopicDescription> futureFor(final String topic, final int partitions) {
+    if (partitions == ABSENT) {
+      final KafkaFutureImpl<TopicDescription> failed = new KafkaFutureImpl<>();
+      failed.completeExceptionally(new UnknownTopicOrPartitionException(topic));
+      return failed;
+    }
+    final List<TopicPartitionInfo> partitionInfos =
+        IntStream.range(0, partitions)
+            .mapToObj(index -> new TopicPartitionInfo(index, Node.noNode(), List.of(), List.of()))
+            .collect(toUnmodifiableList());
+    return KafkaFuture.completedFuture(new TopicDescription(topic, false, partitionInfos));
+  }
+}
diff --git a/kafka-streams-framework/src/test/java/org/hypertrace/core/kafkastreams/framework/threading/StreamThreadsCountResolverTest.java b/kafka-streams-framework/src/test/java/org/hypertrace/core/kafkastreams/framework/threading/StreamThreadsCountResolverTest.java
new file mode 100644
index 0000000..a546e75
--- /dev/null
+++ b/kafka-streams-framework/src/test/java/org/hypertrace/core/kafkastreams/framework/threading/StreamThreadsCountResolverTest.java
@@ -0,0 +1,111 @@
+package org.hypertrace.core.kafkastreams.framework.threading;
+
+import static java.util.stream.Collectors.toUnmodifiableList;
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.mockito.ArgumentMatchers.anySet;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.when;
+
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Optional;
+import java.util.OptionalInt;
+import java.util.Properties;
+import java.util.function.Function;
+import java.util.function.IntSupplier;
+import java.util.stream.IntStream;
+import org.apache.kafka.clients.CommonClientConfigs;
+import org.apache.kafka.clients.admin.AdminClient;
+import org.apache.kafka.clients.admin.DescribeTopicsResult;
+import org.apache.kafka.clients.admin.TopicDescription;
+import org.apache.kafka.common.KafkaFuture;
+import org.apache.kafka.common.Node;
+import org.apache.kafka.common.TopicPartitionInfo;
+import org.apache.kafka.common.serialization.Serdes;
+import org.apache.kafka.streams.StreamsBuilder;
+import org.apache.kafka.streams.Topology;
+import org.apache.kafka.streams.kstream.Consumed;
+import org.junit.jupiter.api.Test;
+
+class StreamThreadsCountResolverTest {
+
+  @Test
+  void calculatorThrowReturnsEmpty() {
+    final AdminClient adminClient = mock(AdminClient.class);
+    when(adminClient.describeTopics(anySet()))
+        .thenThrow(new RuntimeException("simulated broker outage"));
+    final StreamThreadsCountResolver resolver = resolverWith(() -> 8, adminClient);
+
+    assertEquals(OptionalInt.empty(), resolver.resolve(simpleTopology(), bootstrapProperties()));
+  }
+
+  @Test
+  void missingReplicaCountReturnsEmpty() {
+    final IntSupplier missing = StreamThreadsCountResolver.optionalReplicaCount(Optional.empty());
+    final StreamThreadsCountResolver resolver = resolverWith(missing, mock(AdminClient.class));
+
+    assertEquals(OptionalInt.empty(), resolver.resolve(simpleTopology(), bootstrapProperties()));
+  }
+
+  @Test
+  void zeroReplicaCountReturnsEmpty() {
+    final StreamThreadsCountResolver resolver = resolverWith(() -> 0, mock(AdminClient.class));
+
+    assertEquals(OptionalInt.empty(), resolver.resolve(simpleTopology(), bootstrapProperties()));
+  }
+
+  @Test
+  void negativeReplicaCountReturnsEmpty() {
+    final StreamThreadsCountResolver resolver = resolverWith(() -> -1, mock(AdminClient.class));
+
+    assertEquals(OptionalInt.empty(), resolver.resolve(simpleTopology(), bootstrapProperties()));
+  }
+
+  @Test
+  void delegatesToCalculatorWithConfiguredReplicas() {
+    // 30-partition source / 8 replicas -> ceil(30/8) = 4 threads. Real calculator computes this;
+    // only AdminClient (the external) is stubbed.
+    final AdminClient adminClient = mock(AdminClient.class);
+    stubPartitions(adminClient, Map.of("topic-a", 30));
+    final StreamThreadsCountResolver resolver = resolverWith(() -> 8, adminClient);
+
+    assertEquals(OptionalInt.of(4), resolver.resolve(simpleTopology(), bootstrapProperties()));
+  }
+
+  private static StreamThreadsCountResolver resolverWith(
+      final IntSupplier replicaCountSupplier, final AdminClient adminClient) {
+    final Function<Properties, AdminClient> factory = properties -> adminClient;
+    return new StreamThreadsCountResolver(
+        new DynamicStreamThreadsCountCalculator(), replicaCountSupplier, factory);
+  }
+
+  private static Topology simpleTopology() {
+    final StreamsBuilder builder = new StreamsBuilder();
+    builder.stream("topic-a", Consumed.with(Serdes.String(), Serdes.String()))
+        .foreach((key, value) -> {});
+    return builder.build();
+  }
+
+  private static Map<String, Object> bootstrapProperties() {
+    return Map.of(CommonClientConfigs.BOOTSTRAP_SERVERS_CONFIG, "localhost:9092");
+  }
+
+  private static void stubPartitions(
+      final AdminClient adminClient, final Map<String, Integer> partitionsByTopic) {
+    final DescribeTopicsResult result = mock(DescribeTopicsResult.class);
+    final Map<String, KafkaFuture<TopicDescription>> futures = new HashMap<>();
+    partitionsByTopic.forEach(
+        (topic, partitionCount) -> futures.put(topic, futureFor(topic, partitionCount)));
+    when(result.topicNameValues()).thenReturn(futures);
+    when(adminClient.describeTopics(anySet())).thenReturn(result);
+  }
+
+  private static KafkaFuture<TopicDescription> futureFor(final String topic, final int partitions) {
+    final List<TopicPartitionInfo> partitionInfos =
+        IntStream.range(0, partitions)
+            .mapToObj(index -> new TopicPartitionInfo(index, Node.noNode(), List.of(), List.of()))
+            .collect(toUnmodifiableList());
+    return KafkaFuture.completedFuture(new TopicDescription(topic, false, partitionInfos));
+  }
+}