diff --git a/ratis-netty/src/main/java/org/apache/ratis/netty/server/ReadStreamManagement.java b/ratis-netty/src/main/java/org/apache/ratis/netty/server/ReadStreamManagement.java index 5336760a0b..81bc55b611 100644 --- a/ratis-netty/src/main/java/org/apache/ratis/netty/server/ReadStreamManagement.java +++ b/ratis-netty/src/main/java/org/apache/ratis/netty/server/ReadStreamManagement.java @@ -17,6 +17,7 @@ */ package org.apache.ratis.netty.server; +import org.apache.ratis.client.impl.OrderedAsync; import org.apache.ratis.conf.RaftProperties; import org.apache.ratis.datastream.impl.DataStreamReplyByteBuffer; import org.apache.ratis.datastream.impl.DataStreamRequestByteBuf; @@ -47,6 +48,7 @@ import static org.apache.ratis.client.impl.ClientProtoUtils.toRaftClientRequest; import static org.apache.ratis.client.impl.ClientProtoUtils.toRaftClientReplyProto; +import static org.apache.ratis.netty.server.DataStreamManagement.newDataStreamReplyByteBuffer; import static org.apache.ratis.netty.server.DataStreamManagement.replyDataStreamException; public class ReadStreamManagement { @@ -60,24 +62,12 @@ static class ReadStream implements WritableByteChannel { private final DataStreamReplyByteBuffer terminalReply; private long streamOffset; - ReadStream(RaftClientRequest request, long streamId, ChannelHandlerContext ctx) { + ReadStream(RaftClientRequest request, long streamId, ChannelHandlerContext ctx, RaftClientReply terminalReply) { this.clientId = request.getClientId(); this.streamId = streamId; this.ctx = ctx; - final RaftClientReply reply = RaftClientReply.newBuilder() - .setRequest(request) - .setSuccess() - .build(); - this.terminalReply = DataStreamReplyByteBuffer.newBuilder() - .setClientId(clientId) - .setType(Type.STREAM_HEADER) - .setStreamId(streamId) - .setStreamOffset(0) - .setBuffer(toRaftClientReplyProto(reply).toByteString().asReadOnlyByteBuffer()) - .setSuccess(true) - .setBytesWritten(0) - .build(); + this.terminalReply = newReadStreamTerminalReply(clientId, streamId, terminalReply); } @Override @@ -186,17 +176,80 @@ private boolean processImpl(DataStreamRequestByteBuf requestBuf, ChannelHandlerC return true; } - final ReadStream stream = new ReadStream(request, requestBuf.getStreamId(), ctx); - requestExecutor.execute(() -> { + final CompletableFuture readOnlyCheck; + try { + readOnlyCheck = server.submitClientRequestAsync(newDummyReadRequest(request)); + } catch (IOException e) { + replyDataStreamException(server, e, request, requestBuf, ctx); + return true; + } + + readOnlyCheck.whenCompleteAsync((reply, exception) -> { + if (exception != null) { + replyDataStreamException(server, exception, request, requestBuf, ctx); + return; + } + + final RaftClientReply terminalReply = toReadStreamReply(request, reply); + if (!reply.isSuccess()) { + ctx.writeAndFlush(newDataStreamReplyByteBuffer(requestBuf, terminalReply)); + return; + } + + final ReadStream stream = new ReadStream(request, requestBuf.getStreamId(), ctx, terminalReply); try { division.getStateMachine().data().query(request.getMessage(), stream); } catch (Throwable t) { LOG.error("{}: Failed read-only data stream query for {}", this, request, t); } - }); + }, requestExecutor); return true; } + private static RaftClientRequest newDummyReadRequest(RaftClientRequest request) { + final RaftClientRequest.Builder builder = RaftClientRequest.newBuilder() + .setClientId(request.getClientId()) + .setGroupId(request.getRaftGroupId()) + .setCallId(request.getCallId()) + .setMessage(OrderedAsync.DUMMY) + .setType(request.getType()) + .setRepliedCallIds(request.getRepliedCallIds()) + .setSlidingWindowEntry(request.getSlidingWindowEntry()) + .setRoutingTable(request.getRoutingTable()) + .setTimeoutMs(request.getTimeoutMs()) + .setSpanContext(request.getSpanContext()); + if (request.isToLeader()) { + builder.setLeaderId(request.getServerId()); + } else { + builder.setServerId(request.getServerId()); + } + return builder.build(); + } + + private static RaftClientReply toReadStreamReply(RaftClientRequest request, RaftClientReply reply) { + return RaftClientReply.newBuilder() + .setRequest(request) + .setSuccess(reply.isSuccess()) + .setException(reply.getException()) + .setLogIndex(reply.getLogIndex()) + .setCommitInfos(reply.getCommitInfos()) + .build(); + } + + private static DataStreamReplyByteBuffer newReadStreamTerminalReply( + ClientId clientId, long streamId, RaftClientReply reply) { + return DataStreamReplyByteBuffer.newBuilder() + .setClientId(clientId) + .setType(Type.STREAM_HEADER) + .setStreamId(streamId) + .setStreamOffset(0) + .setBuffer(toRaftClientReplyProto(reply).toByteString().asReadOnlyByteBuffer()) + .setSuccess(reply.isSuccess()) + .setBytesWritten(0) + .setCommitInfos(reply.getCommitInfos()) + .build(); + } + @Override public String toString() { return name; diff --git a/ratis-server/src/main/java/org/apache/ratis/server/impl/RaftServerImpl.java b/ratis-server/src/main/java/org/apache/ratis/server/impl/RaftServerImpl.java index f758fd0edc..84a39ba1c0 100644 --- a/ratis-server/src/main/java/org/apache/ratis/server/impl/RaftServerImpl.java +++ b/ratis-server/src/main/java/org/apache/ratis/server/impl/RaftServerImpl.java @@ -1113,11 +1113,12 @@ private CompletableFuture readAsync(RaftClientRequest request) if (request.getType().getRead().getPreferNonLinearizable() || readOption == RaftServerConfigKeys.Read.Option.DEFAULT) { final CompletableFuture reply = checkLeaderState(request); - if (reply != null) { - return reply; - } - return queryStateMachine(request); - } else if (readOption == RaftServerConfigKeys.Read.Option.LINEARIZABLE){ + if (reply != null) { + return reply; + } + return isDummyRead(request) ? CompletableFuture.completedFuture(newSuccessReply(request)) + : queryStateMachine(request); + } else if (readOption == RaftServerConfigKeys.Read.Option.LINEARIZABLE) { final LeaderStateImpl leader = role.getLeaderState().orElse(null); final CompletableFuture replyFuture; if (leader != null) { @@ -1136,12 +1137,17 @@ private CompletableFuture readAsync(RaftClientRequest request) return replyFuture .thenCompose(readIndex -> getState().getReadRequests().waitToAdvance(readIndex, () -> getReadException("add", snapshotInstallationHandler.getInProgressInstallSnapshotIndex(), false))) - .thenCompose(readIndex -> queryStateMachine(request)) + .thenCompose(readIndex -> isDummyRead(request) + ? CompletableFuture.completedFuture(newSuccessReply(request)) : queryStateMachine(request)) .exceptionally(e -> readException2Reply(request, e)); } else { throw new IllegalStateException("Unexpected read option: " + readOption); } } + private static boolean isDummyRead(RaftClientRequest request) { + return request.getMessage() != null && OrderedAsync.DUMMY.getContent().equals(request.getMessage().getContent()); + } + private RaftClientReply readException2Reply(RaftClientRequest request, Throwable e) { e = JavaUtils.unwrapCompletionException(e); if (e instanceof StateMachineException ) { diff --git a/ratis-test/src/test/java/org/apache/ratis/netty/server/TestDataStreamManagement.java b/ratis-test/src/test/java/org/apache/ratis/netty/server/TestDataStreamManagement.java index 188f119fca..56ecd166bc 100644 --- a/ratis-test/src/test/java/org/apache/ratis/netty/server/TestDataStreamManagement.java +++ b/ratis-test/src/test/java/org/apache/ratis/netty/server/TestDataStreamManagement.java @@ -19,6 +19,7 @@ import org.apache.ratis.client.impl.ClientProtoUtils; import org.apache.ratis.client.impl.DataStreamClientImpl.DataStreamOutputImpl; +import org.apache.ratis.client.impl.OrderedAsync; import org.apache.ratis.conf.RaftProperties; import org.apache.ratis.datastream.impl.DataStreamReplyByteBuffer; import org.apache.ratis.datastream.impl.DataStreamRequestByteBuf; @@ -28,10 +29,12 @@ import org.apache.ratis.protocol.ClientId; import org.apache.ratis.protocol.DataStreamReply; import org.apache.ratis.protocol.Message; +import org.apache.ratis.protocol.RaftClientReply; import org.apache.ratis.protocol.RaftClientRequest; import org.apache.ratis.protocol.RaftGroupId; import org.apache.ratis.protocol.RaftPeer; import org.apache.ratis.protocol.RaftPeerId; +import org.apache.ratis.protocol.exceptions.ReadIndexException; import org.apache.ratis.server.RaftServer; import org.apache.ratis.statemachine.StateMachine; import org.apache.ratis.statemachine.StateMachine.DataApi; @@ -59,14 +62,17 @@ import java.util.Collections; import java.util.List; import java.util.Set; +import java.util.concurrent.CompletableFuture; import java.util.concurrent.CountDownLatch; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicReference; +import java.util.function.Function; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertFalse; import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertNull; import static org.junit.jupiter.api.Assertions.assertTrue; class TestDataStreamManagement { @@ -96,8 +102,9 @@ public void query(Message request, WritableByteChannel stream) { assertTrue(management.process(readOnlyRequest.request, embeddedChannel.pipeline().firstContext())); assertEquals(0, readOnlyRequest.headerBuf.refCnt()); + JavaUtils.attempt(() -> assertNotNull(streamRef.get()), 10, + TimeDuration.valueOf(100, TimeUnit.MILLISECONDS), "read-only stream", null); final WritableByteChannel stream = streamRef.get(); - assertNotNull(stream); stream.write(response.asReadOnlyByteBuffer()); stream.close(); @@ -116,6 +123,99 @@ public void query(Message request, WritableByteChannel stream) { assertTrue(ClientProtoUtils.getRaftClientReply(replies.get(1)).isSuccess()); } finally { embeddedChannel.finishAndReleaseAll(); + management.shutdown(); + } + } + + @Test + void readOnlyRequestWaitsForLinearizableCheck() throws Exception { + final RaftPeerId serverId = RaftPeerId.valueOf("s1"); + final ClientId clientId = ClientId.randomId(); + final RaftGroupId groupId = RaftGroupId.randomId(); + final ByteString query = ByteString.copyFromUtf8("query"); + final CompletableFuture readOnlyCheck = new CompletableFuture<>(); + final AtomicReference submittedReadOnlyCheck = new AtomicReference<>(); + final AtomicReference messageRef = new AtomicReference<>(); + final AtomicReference streamRef = new AtomicReference<>(); + + final DataApi dataApi = new DataApi() { + @Override + public void query(Message request, WritableByteChannel stream) { + messageRef.set(request); + streamRef.set(stream); + } + }; + final ReadStreamManagement management = newReadStreamManagement(serverId, groupId, dataApi, request -> { + submittedReadOnlyCheck.set(request); + return readOnlyCheck; + }); + final EmbeddedChannel embeddedChannel = new EmbeddedChannel(new ChannelInboundHandlerAdapter()); + final ReadOnlyRequest readOnlyRequest = newReadOnlyRequest(clientId, serverId, groupId, 1L, query); + + try { + assertTrue(management.process(readOnlyRequest.request, embeddedChannel.pipeline().firstContext())); + assertEquals(0, readOnlyRequest.headerBuf.refCnt()); + + final RaftClientRequest checkRequest = submittedReadOnlyCheck.get(); + assertNotNull(checkRequest); + assertEquals(OrderedAsync.DUMMY.getContent(), checkRequest.getMessage().getContent()); + assertNull(streamRef.get(), "state machine query should wait for the read-only check"); + + readOnlyCheck.complete(RaftClientReply.newBuilder().setRequest(checkRequest).setSuccess().build()); + JavaUtils.attempt(() -> assertNotNull(streamRef.get()), 10, + TimeDuration.valueOf(100, TimeUnit.MILLISECONDS), "linearizable read-only stream", null); + assertEquals(query, messageRef.get().getContent()); + } finally { + embeddedChannel.finishAndReleaseAll(); + management.shutdown(); + } + } + + @Test + void readOnlyCheckFailureSkipsStateMachineQuery() throws Exception { + final RaftPeerId serverId = RaftPeerId.valueOf("s1"); + final ClientId clientId = ClientId.randomId(); + final RaftGroupId groupId = RaftGroupId.randomId(); + final ByteString query = ByteString.copyFromUtf8("query"); + final AtomicBoolean queryCalled = new AtomicBoolean(); + + final DataApi dataApi = new DataApi() { + @Override + public void query(Message request, WritableByteChannel stream) { + queryCalled.set(true); + } + }; + final ReadStreamManagement management = newReadStreamManagement(serverId, groupId, dataApi, request -> + CompletableFuture.completedFuture(RaftClientReply.newBuilder() + .setRequest(request) + .setException(new ReadIndexException("read index failed")) + .build())); + final EmbeddedChannel embeddedChannel = new EmbeddedChannel(new ChannelInboundHandlerAdapter()); + final ReadOnlyRequest readOnlyRequest = newReadOnlyRequest(clientId, serverId, groupId, 1L, query); + + try { + assertTrue(management.process(readOnlyRequest.request, embeddedChannel.pipeline().firstContext())); + assertEquals(0, readOnlyRequest.headerBuf.refCnt()); + + final List replies = new ArrayList<>(); + JavaUtils.attempt(() -> { + for (Object outbound; (outbound = embeddedChannel.readOutbound()) != null;) { + replies.add((DataStreamReply) outbound); + } + assertEquals(1, replies.size()); + }, 10, TimeDuration.valueOf(100, TimeUnit.MILLISECONDS), "read-only check failure reply", null); + + assertFalse(queryCalled.get(), "state machine query should not run when the read-only check fails"); + final DataStreamReply reply = replies.get(0); + assertEquals(Type.STREAM_HEADER, reply.getType()); + assertFalse(reply.isSuccess()); + final RaftClientReply clientReply = ClientProtoUtils.getRaftClientReply(reply); + assertFalse(clientReply.isSuccess()); + assertNotNull(clientReply.getReadIndexException()); + assertEquals(serverId, clientReply.getServerId()); + } finally { + embeddedChannel.finishAndReleaseAll(); + management.shutdown(); } } @@ -231,13 +331,19 @@ private static class ReadOnlyRequest { private static ReadStreamManagement newReadStreamManagement( RaftPeerId serverId, RaftGroupId groupId, DataApi dataApi) { + return newReadStreamManagement(serverId, groupId, dataApi, TestDataStreamManagement::successReply); + } + + private static ReadStreamManagement newReadStreamManagement(RaftPeerId serverId, RaftGroupId groupId, + DataApi dataApi, Function> submitClientRequestAsync) { final StateMachine stateMachine = new BaseStateMachine() { @Override public DataApi data() { return dataApi; } }; - final RaftServer server = newRaftServer(serverId, new RaftProperties(), groupId, newDivision(stateMachine)); + final RaftServer server = newRaftServer(serverId, new RaftProperties(), groupId, newDivision(stateMachine), + submitClientRequestAsync); return new ReadStreamManagement(server); } @@ -275,6 +381,16 @@ private static RaftServer newRaftServer(RaftPeerId serverId, RaftProperties prop private static RaftServer newRaftServer(RaftPeerId serverId, RaftProperties properties, RaftGroupId groupId, RaftServer.Division division) { + return newRaftServer(serverId, properties, groupId, division, TestDataStreamManagement::successReply); + } + + private static CompletableFuture successReply(RaftClientRequest request) { + return CompletableFuture.completedFuture(RaftClientReply.newBuilder().setRequest(request).setSuccess().build()); + } + + private static RaftServer newRaftServer(RaftPeerId serverId, RaftProperties properties, + RaftGroupId groupId, RaftServer.Division division, + Function> submitClientRequestAsync) { return (RaftServer) Proxy.newProxyInstance(RaftServer.class.getClassLoader(), new Class[]{RaftServer.class}, (proxy, method, args) -> { switch (method.getName()) { @@ -287,6 +403,8 @@ private static RaftServer newRaftServer(RaftPeerId serverId, RaftProperties prop return division; } throw new IOException("Division not found: " + args[0]); + case "submitClientRequestAsync": + return submitClientRequestAsync.apply((RaftClientRequest) args[0]); case "close": return null; case "toString":