diff --git a/java-bigtable/google-cloud-bigtable/src/main/java/com/google/cloud/bigtable/data/v2/models/Query.java b/java-bigtable/google-cloud-bigtable/src/main/java/com/google/cloud/bigtable/data/v2/models/Query.java index ffac399e7262..ae5de38e5b08 100644 --- a/java-bigtable/google-cloud-bigtable/src/main/java/com/google/cloud/bigtable/data/v2/models/Query.java +++ b/java-bigtable/google-cloud-bigtable/src/main/java/com/google/cloud/bigtable/data/v2/models/Query.java @@ -318,10 +318,39 @@ public TargetId getTargetId() { return targetId; } + /** + * Returns true if this query identifies a single row that can be served by a point read. Supports + * two shapes: exactly one row key and no row ranges, or exactly one closed-closed row range whose + * start key equals its end key. + */ + @InternalApi + public boolean isSinglePointQuery() { + RowSet rows = this.builder.getRows(); + int keyCount = rows.getRowKeysCount(); + int rangeCount = rows.getRowRangesCount(); + if (keyCount == 1 && rangeCount == 0) { + return true; + } + if (keyCount == 0 && rangeCount == 1) { + RowRange range = rows.getRowRanges(0); + return range.hasStartKeyClosed() + && range.hasEndKeyClosed() + && range.getStartKeyClosed().equals(range.getEndKeyClosed()); + } + return false; + } + @InternalApi public SessionReadRowRequest toSessionPointProto() { + Preconditions.checkState( + isSinglePointQuery(), + "Query must be a single-point read (one row key, or one closed-closed row range whose" + + " start equals its end)"); + RowSet rows = this.builder.getRows(); + ByteString key = + rows.getRowKeysCount() > 0 ? rows.getRowKeys(0) : rows.getRowRanges(0).getStartKeyClosed(); return SessionReadRowRequest.newBuilder() - .setKey(this.builder.getRows().getRowKeysList().get(0)) + .setKey(key) .setFilter(this.builder.getFilter()) .build(); } diff --git a/java-bigtable/google-cloud-bigtable/src/main/java/com/google/cloud/bigtable/data/v2/stub/EnhancedBigtableStub.java b/java-bigtable/google-cloud-bigtable/src/main/java/com/google/cloud/bigtable/data/v2/stub/EnhancedBigtableStub.java index 56b7d634f11a..aa01e0a57f37 100644 --- a/java-bigtable/google-cloud-bigtable/src/main/java/com/google/cloud/bigtable/data/v2/stub/EnhancedBigtableStub.java +++ b/java-bigtable/google-cloud-bigtable/src/main/java/com/google/cloud/bigtable/data/v2/stub/EnhancedBigtableStub.java @@ -27,6 +27,7 @@ import com.google.api.gax.retrying.BasicResultRetryAlgorithm; import com.google.api.gax.retrying.ExponentialRetryAlgorithm; import com.google.api.gax.retrying.RetryAlgorithm; +import com.google.api.gax.retrying.RetrySettings; import com.google.api.gax.retrying.RetryingExecutorWithContext; import com.google.api.gax.retrying.ScheduledRetryingExecutor; import com.google.api.gax.retrying.SimpleStreamResumptionStrategy; @@ -37,6 +38,7 @@ import com.google.api.gax.rpc.RequestParamsExtractor; import com.google.api.gax.rpc.ServerStreamingCallSettings; import com.google.api.gax.rpc.ServerStreamingCallable; +import com.google.api.gax.rpc.StatusCode; import com.google.api.gax.rpc.UnaryCallSettings; import com.google.api.gax.rpc.UnaryCallable; import com.google.api.gax.tracing.SpanName; @@ -97,6 +99,7 @@ import com.google.cloud.bigtable.data.v2.stub.mutaterows.MutateRowsRetryingCallable; import com.google.cloud.bigtable.data.v2.stub.readrows.FilterMarkerRowsCallable; import com.google.cloud.bigtable.data.v2.stub.readrows.LargeReadRowsResumptionStrategy; +import com.google.cloud.bigtable.data.v2.stub.readrows.MaybePointReadCallable; import com.google.cloud.bigtable.data.v2.stub.readrows.ReadRowsBatchingDescriptor; import com.google.cloud.bigtable.data.v2.stub.readrows.ReadRowsResumptionStrategy; import com.google.cloud.bigtable.data.v2.stub.readrows.ReadRowsRetryCompletedCallable; @@ -119,6 +122,7 @@ import java.time.Duration; import java.util.List; import java.util.Map; +import java.util.Set; import java.util.concurrent.TimeUnit; import java.util.function.Function; import javax.annotation.Nonnull; @@ -259,11 +263,20 @@ public ServerStreamingCallable createReadRowsCallable( bigtableClientContext.getClientContext().getTracerFactory(), span); - return traced.withDefaultCallContext( - bigtableClientContext - .getClientContext() - .getDefaultCallContext() - .withRetrySettings(perOpSettings.readRowsSettings.getRetrySettings())); + ServerStreamingCallable classic = + traced.withDefaultCallContext( + bigtableClientContext + .getClientContext() + .getDefaultCallContext() + .withRetrySettings(perOpSettings.readRowsSettings.getRetrySettings())); + + return new MaybePointReadCallable<>( + classic, + createPointReadCallable( + rowAdapter, + "ReadRows", + perOpSettings.readRowsSettings.getRetrySettings(), + perOpSettings.readRowsSettings.getRetryableCodes())); } /** @@ -281,13 +294,25 @@ public ServerStreamingCallable createReadRowsCallable( * */ public UnaryCallable createReadRowCallable(RowAdapter rowAdapter) { + return createPointReadCallable( + rowAdapter, + "ReadRow", + perOpSettings.readRowSettings.getRetrySettings(), + perOpSettings.readRowSettings.getRetryableCodes()); + } + + private UnaryCallable createPointReadCallable( + RowAdapter rowAdapter, + String spanName, + RetrySettings retrySettings, + Set retryableCodes) { ClientContext clientContext = bigtableClientContext.getClientContext(); ServerStreamingCallable readRowsCallable = createReadRowsBaseCallable( ServerStreamingCallSettings.newBuilder() - .setRetryableCodes(perOpSettings.readRowSettings.getRetryableCodes()) - .setRetrySettings(perOpSettings.readRowSettings.getRetrySettings()) + .setRetryableCodes(retryableCodes) + .setRetrySettings(retrySettings) .setIdleTimeoutDuration(Duration.ZERO) .setWaitTimeoutDuration(Duration.ZERO) .build(), @@ -302,16 +327,20 @@ public UnaryCallable createReadRowCallable(RowAdapter BigtableUnaryOperationCallable classic = new BigtableUnaryOperationCallable<>( readRowCallable, - clientContext - .getDefaultCallContext() - .withRetrySettings(perOpSettings.readRowSettings.getRetrySettings()), + clientContext.getDefaultCallContext().withRetrySettings(retrySettings), clientContext.getTracerFactory(), - getSpanName("ReadRow"), + getSpanName(spanName), /* allowNoResponse= */ true); + UnaryCallSettings shimSettings = + perOpSettings.readRowSettings.toBuilder() + .setRetrySettings(retrySettings) + .setRetryableCodes(retryableCodes) + .build(); + return bigtableClientContext .getSessionShim() - .decorateReadRow(classic, rowAdapter, perOpSettings.readRowSettings); + .decorateReadRow(classic, rowAdapter, shimSettings); } private ServerStreamingCallable createReadRowsBaseCallable( diff --git a/java-bigtable/google-cloud-bigtable/src/main/java/com/google/cloud/bigtable/data/v2/stub/readrows/MaybePointReadCallable.java b/java-bigtable/google-cloud-bigtable/src/main/java/com/google/cloud/bigtable/data/v2/stub/readrows/MaybePointReadCallable.java new file mode 100644 index 000000000000..5530cd346612 --- /dev/null +++ b/java-bigtable/google-cloud-bigtable/src/main/java/com/google/cloud/bigtable/data/v2/stub/readrows/MaybePointReadCallable.java @@ -0,0 +1,119 @@ +/* + * Copyright 2026 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.google.cloud.bigtable.data.v2.stub.readrows; + +import com.google.api.core.ApiFuture; +import com.google.api.core.ApiFutureCallback; +import com.google.api.core.ApiFutures; +import com.google.api.core.InternalApi; +import com.google.api.gax.rpc.ApiCallContext; +import com.google.api.gax.rpc.ResponseObserver; +import com.google.api.gax.rpc.ServerStreamingCallable; +import com.google.api.gax.rpc.StreamController; +import com.google.api.gax.rpc.UnaryCallable; +import com.google.cloud.bigtable.data.v2.models.Query; +import com.google.common.util.concurrent.MoreExecutors; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicReference; + +/** + * Routes ReadRows calls whose query identifies a single row through a unary point-read callable, + * letting them benefit from the same session-shim diversion as {@code ReadRow}. Queries that cannot + * be reduced to a point read fall through to the classic {@code ReadRows} callable. + */ +@InternalApi +public class MaybePointReadCallable extends ServerStreamingCallable { + private final ServerStreamingCallable classic; + private final UnaryCallable pointReader; + + public MaybePointReadCallable( + ServerStreamingCallable classic, UnaryCallable pointReader) { + this.classic = classic; + this.pointReader = pointReader; + } + + @Override + public void call(Query request, ResponseObserver responseObserver, ApiCallContext context) { + if (!request.isSinglePointQuery()) { + classic.call(request, responseObserver, context); + return; + } + + AtomicBoolean cancelled = new AtomicBoolean(); + AtomicReference> futureRef = new AtomicReference<>(); + + responseObserver.onStart( + new StreamController() { + @Override + public void cancel() { + cancelled.set(true); + ApiFuture f = futureRef.get(); + if (f != null) { + f.cancel(false); + } + } + + @Override + public void disableAutoInboundFlowControl() {} + + @Override + public void request(int count) {} + }); + + ApiFuture future; + try { + future = pointReader.futureCall(request, context); + } catch (Throwable t) { + if (!cancelled.get()) { + responseObserver.onError(t); + } + return; + } + futureRef.set(future); + if (cancelled.get()) { + future.cancel(false); + } + + ApiFutures.addCallback( + future, + new ApiFutureCallback() { + @Override + public void onSuccess(RowT row) { + if (cancelled.get()) { + return; + } + if (row != null) { + try { + responseObserver.onResponse(row); + } catch (Throwable t) { + responseObserver.onError(t); + return; + } + } + responseObserver.onComplete(); + } + + @Override + public void onFailure(Throwable t) { + if (cancelled.get()) { + return; + } + responseObserver.onError(t); + } + }, + MoreExecutors.directExecutor()); + } +} diff --git a/java-bigtable/google-cloud-bigtable/src/test/java/com/google/cloud/bigtable/data/v2/models/QueryTest.java b/java-bigtable/google-cloud-bigtable/src/test/java/com/google/cloud/bigtable/data/v2/models/QueryTest.java index b7c394eb1539..4a2bb337b6b2 100644 --- a/java-bigtable/google-cloud-bigtable/src/test/java/com/google/cloud/bigtable/data/v2/models/QueryTest.java +++ b/java-bigtable/google-cloud-bigtable/src/test/java/com/google/cloud/bigtable/data/v2/models/QueryTest.java @@ -23,6 +23,7 @@ import com.google.bigtable.v2.RowFilter; import com.google.bigtable.v2.RowRange; import com.google.bigtable.v2.RowSet; +import com.google.bigtable.v2.SessionReadRowRequest; import com.google.cloud.bigtable.data.v2.internal.ByteStringComparator; import com.google.cloud.bigtable.data.v2.internal.NameUtil; import com.google.cloud.bigtable.data.v2.internal.RequestContext; @@ -966,4 +967,113 @@ public void testQueryReversed() { assertThat(query.toProto(requestContext)) .isEqualTo(expectedReadFromTableProtoBuilder().setReversed(true).build()); } + + @Test + public void isSinglePointQuery_singleRowKey() { + assertThat(Query.create(TABLE_ID).rowKey("k").isSinglePointQuery()).isTrue(); + } + + @Test + public void isSinglePointQuery_singleClosedRange() { + assertThat( + Query.create(TABLE_ID) + .range(ByteStringRange.unbounded().startClosed("k").endClosed("k")) + .isSinglePointQuery()) + .isTrue(); + } + + @Test + public void isSinglePointQuery_emptyQuery() { + assertThat(Query.create(TABLE_ID).isSinglePointQuery()).isFalse(); + } + + @Test + public void isSinglePointQuery_multipleRowKeys() { + assertThat(Query.create(TABLE_ID).rowKey("a").rowKey("b").isSinglePointQuery()).isFalse(); + } + + @Test + public void isSinglePointQuery_rowKeyAndRange() { + assertThat( + Query.create(TABLE_ID) + .rowKey("a") + .range(ByteStringRange.unbounded().startClosed("a").endClosed("a")) + .isSinglePointQuery()) + .isFalse(); + } + + @Test + public void isSinglePointQuery_multipleRanges() { + assertThat( + Query.create(TABLE_ID) + .range(ByteStringRange.unbounded().startClosed("a").endClosed("a")) + .range(ByteStringRange.unbounded().startClosed("b").endClosed("b")) + .isSinglePointQuery()) + .isFalse(); + } + + @Test + public void isSinglePointQuery_closedOpenRange() { + assertThat( + Query.create(TABLE_ID) + .range(ByteStringRange.unbounded().startClosed("k").endOpen("k")) + .isSinglePointQuery()) + .isFalse(); + } + + @Test + public void isSinglePointQuery_unequalClosedRange() { + assertThat( + Query.create(TABLE_ID) + .range(ByteStringRange.unbounded().startClosed("a").endClosed("b")) + .isSinglePointQuery()) + .isFalse(); + } + + @Test + public void isSinglePointQuery_prefixRange() { + assertThat(Query.create(TABLE_ID).prefix("k").isSinglePointQuery()).isFalse(); + } + + @Test + public void toSessionPointProto_fromRowKey() { + Query query = Query.create(TABLE_ID).rowKey("the-key"); + assertThat(query.toSessionPointProto()) + .isEqualTo( + SessionReadRowRequest.newBuilder() + .setKey(ByteString.copyFromUtf8("the-key")) + .setFilter(RowFilter.getDefaultInstance()) + .build()); + } + + @Test + public void toSessionPointProto_fromClosedRange() { + Query query = + Query.create(TABLE_ID) + .range(ByteStringRange.unbounded().startClosed("the-key").endClosed("the-key")); + assertThat(query.toSessionPointProto()) + .isEqualTo( + SessionReadRowRequest.newBuilder() + .setKey(ByteString.copyFromUtf8("the-key")) + .setFilter(RowFilter.getDefaultInstance()) + .build()); + } + + @Test + public void toSessionPointProto_preservesFilter() { + RowFilter filter = FILTERS.key().regex("regex").toProto(); + Query query = Query.create(TABLE_ID).rowKey("the-key").filter(FILTERS.key().regex("regex")); + assertThat(query.toSessionPointProto()) + .isEqualTo( + SessionReadRowRequest.newBuilder() + .setKey(ByteString.copyFromUtf8("the-key")) + .setFilter(filter) + .build()); + } + + @Test + public void toSessionPointProto_rejectsNonSinglePointQuery() { + Query query = Query.create(TABLE_ID).rowKey("a").rowKey("b"); + assertThrows(IllegalStateException.class, query::toSessionPointProto); + } } diff --git a/java-bigtable/google-cloud-bigtable/src/test/java/com/google/cloud/bigtable/data/v2/stub/readrows/MaybePointReadCallableTest.java b/java-bigtable/google-cloud-bigtable/src/test/java/com/google/cloud/bigtable/data/v2/stub/readrows/MaybePointReadCallableTest.java new file mode 100644 index 000000000000..42a6c28a9d88 --- /dev/null +++ b/java-bigtable/google-cloud-bigtable/src/test/java/com/google/cloud/bigtable/data/v2/stub/readrows/MaybePointReadCallableTest.java @@ -0,0 +1,226 @@ +/* + * Copyright 2026 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.google.cloud.bigtable.data.v2.stub.readrows; + +import static com.google.common.truth.Truth.assertThat; + +import com.google.api.core.ApiFuture; +import com.google.api.core.SettableApiFuture; +import com.google.api.gax.rpc.ApiCallContext; +import com.google.api.gax.rpc.ResponseObserver; +import com.google.api.gax.rpc.StreamController; +import com.google.api.gax.rpc.UnaryCallable; +import com.google.cloud.bigtable.data.v2.models.Query; +import com.google.cloud.bigtable.data.v2.models.Range.ByteStringRange; +import com.google.cloud.bigtable.data.v2.models.Row; +import com.google.cloud.bigtable.data.v2.models.RowCell; +import com.google.cloud.bigtable.data.v2.models.TableId; +import com.google.cloud.bigtable.gaxx.testing.FakeStreamingApi.ServerStreamingStashCallable; +import com.google.common.collect.ImmutableList; +import com.google.protobuf.ByteString; +import java.util.ArrayList; +import java.util.List; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +public class MaybePointReadCallableTest { + + private static final TableId TABLE_ID = TableId.of("fake-table"); + private static final Row ROW_A = + Row.create(ByteString.copyFromUtf8("a"), ImmutableList.of()); + private static final Row ROW_B = + Row.create(ByteString.copyFromUtf8("b"), ImmutableList.of()); + + private ServerStreamingStashCallable classic; + private FakePointReader pointReader; + private MaybePointReadCallable callable; + private RecordingObserver observer; + + @BeforeEach + public void setUp() { + classic = new ServerStreamingStashCallable<>(ImmutableList.of(ROW_A, ROW_B)); + pointReader = new FakePointReader(); + callable = new MaybePointReadCallable<>(classic, pointReader); + observer = new RecordingObserver(); + } + + @Test + public void singleRowKey_routesToPointReader() { + Query query = Query.create(TABLE_ID).rowKey("a"); + + callable.call(query, observer, null); + pointReader.response.set(ROW_A); + + assertThat(pointReader.request).isEqualTo(query); + assertThat(observer.responses).containsExactly(ROW_A); + assertThat(observer.completed).isTrue(); + assertThat(observer.error).isNull(); + assertThat(classic.getActualRequest()).isNull(); + } + + @Test + public void singleClosedRange_routesToPointReader() { + Query query = + Query.create(TABLE_ID).range(ByteStringRange.unbounded().startClosed("a").endClosed("a")); + + callable.call(query, observer, null); + pointReader.response.set(ROW_A); + + assertThat(pointReader.request).isEqualTo(query); + assertThat(observer.responses).containsExactly(ROW_A); + assertThat(observer.completed).isTrue(); + } + + @Test + public void multipleRowKeys_fallsThroughToClassic() { + Query query = Query.create(TABLE_ID).rowKey("a").rowKey("b"); + + callable.call(query, observer, null); + + assertThat(pointReader.request).isNull(); + assertThat(classic.getActualRequest()).isEqualTo(query); + assertThat(observer.responses).containsExactly(ROW_A, ROW_B).inOrder(); + assertThat(observer.completed).isTrue(); + } + + @Test + public void unboundedRange_fallsThroughToClassic() { + Query query = Query.create(TABLE_ID); + + callable.call(query, observer, null); + + assertThat(pointReader.request).isNull(); + assertThat(classic.getActualRequest()).isEqualTo(query); + } + + @Test + public void pointReaderReturnsNull_completesWithoutResponse() { + Query query = Query.create(TABLE_ID).rowKey("missing"); + + callable.call(query, observer, null); + pointReader.response.set(null); + + assertThat(observer.responses).isEmpty(); + assertThat(observer.completed).isTrue(); + assertThat(observer.error).isNull(); + } + + @Test + public void pointReaderFails_propagatesErrorToObserver() { + Query query = Query.create(TABLE_ID).rowKey("a"); + RuntimeException failure = new RuntimeException("boom"); + + callable.call(query, observer, null); + pointReader.response.setException(failure); + + assertThat(observer.responses).isEmpty(); + assertThat(observer.completed).isFalse(); + assertThat(observer.error).isSameInstanceAs(failure); + } + + @Test + public void observerCancel_cancelsFutureAndSuppressesError() { + Query query = Query.create(TABLE_ID).rowKey("a"); + + callable.call(query, observer, null); + observer.controller.cancel(); + + assertThat(pointReader.response.isCancelled()).isTrue(); + assertThat(observer.error).isNull(); + assertThat(observer.completed).isFalse(); + } + + @Test + public void cancelBeforeFutureReturns_cancelsAfterFutureAttaches() { + Query query = Query.create(TABLE_ID).rowKey("a"); + pointReader.onCall = () -> observer.controller.cancel(); + + callable.call(query, observer, null); + + assertThat(pointReader.response.isCancelled()).isTrue(); + assertThat(observer.error).isNull(); + } + + @Test + public void futureCallThrows_routesThroughOnError() { + Query query = Query.create(TABLE_ID).rowKey("a"); + RuntimeException failure = new RuntimeException("sync boom"); + pointReader.syncFailure = failure; + + callable.call(query, observer, null); + + assertThat(observer.controller).isNotNull(); + assertThat(observer.error).isSameInstanceAs(failure); + assertThat(observer.completed).isFalse(); + } + + @Test + public void futureCallThrowsAfterCancel_suppressesError() { + Query query = Query.create(TABLE_ID).rowKey("a"); + pointReader.syncFailure = new RuntimeException("sync boom"); + pointReader.onCall = () -> observer.controller.cancel(); + + callable.call(query, observer, null); + + assertThat(observer.error).isNull(); + } + + private static class FakePointReader extends UnaryCallable { + Query request; + final SettableApiFuture response = SettableApiFuture.create(); + RuntimeException syncFailure; + Runnable onCall; + + @Override + public ApiFuture futureCall(Query request, ApiCallContext context) { + this.request = request; + if (onCall != null) { + onCall.run(); + } + if (syncFailure != null) { + throw syncFailure; + } + return response; + } + } + + private static class RecordingObserver implements ResponseObserver { + final List responses = new ArrayList<>(); + boolean completed; + Throwable error; + StreamController controller; + + @Override + public void onStart(StreamController controller) { + this.controller = controller; + } + + @Override + public void onResponse(Row response) { + responses.add(response); + } + + @Override + public void onError(Throwable t) { + error = t; + } + + @Override + public void onComplete() { + completed = true; + } + } +}