diff --git a/examples/features/trpc_stream/client/client_shopping.cc b/examples/features/trpc_stream/client/client_shopping.cc new file mode 100644 index 00000000..ec4fc732 --- /dev/null +++ b/examples/features/trpc_stream/client/client_shopping.cc @@ -0,0 +1,334 @@ +// +// +// Tencent is pleased to support the open source community by making tRPC available. +// +// Copyright (C) 2023 THL A29 Limited, a Tencent company. +// All rights reserved. +// +// If you have downloaded a copy of the tRPC source code from Tencent, +// please note that tRPC source code is licensed under the Apache 2.0 License, +// A copy of the Apache 2.0 License is included in this file. +// +// + +#include +#include +#include +#include +#include + +#include "gflags/gflags.h" + +#include "trpc/client/make_client_context.h" +#include "trpc/client/trpc_client.h" +#include "trpc/common/config/trpc_config.h" +#include "trpc/common/runtime_manager.h" +#include "trpc/common/status.h" +#include "trpc/common/trpc_plugin.h" +#include "trpc/coroutine/fiber.h" +#include "trpc/coroutine/fiber_latch.h" +#include "trpc/util/log/logging.h" + +#include "examples/features/trpc_stream/server/stream.trpc.pb.h" + +DEFINE_string(service_name, "trpc.test.shopping.StreamShopping", "callee service name"); +DEFINE_string(client_config, "trpc_cpp_fiber.yaml", ""); +DEFINE_string(addr, "127.0.0.1:24757", "ip:port"); +DEFINE_string(rpc_method, "ClientStreamShopping", "RPC method name"); +DEFINE_int32(request_count, 3, "count of request"); + +namespace test::shopping { + +using StreamShoppingServiceProxy = ::trpc::test::shopping::StreamShoppingServiceProxy; +using StreamShoppingServiceProxyPtr = std::shared_ptr; + +int GetRequestCount(int request_count) { + if (request_count > 0) { + return std::min(100, FLAGS_request_count); + } + return std::numeric_limits::max(); +} + +bool CallClientStreamShopping(const StreamShoppingServiceProxyPtr& proxy, int request_count) { + auto context = ::trpc::MakeClientContext(proxy); + ::trpc::test::shopping::ShoppingReply reply; + int send_count{0}; + int send_bytes{0}; + ::trpc::Status status{0, 0, "OK"}; + bool ok{true}; + do { + auto stream = proxy->ClientStreamShopping(context, &reply); + if (!stream.GetStatus().OK()) { + std::cerr << "stream error:" << stream.GetStatus().ToString() << std::endl; + ok = false; + break; + } + + for (int i = 0; i < request_count; ++i) { + ::trpc::test::shopping::ShoppingRequest request; + request.set_msg("ClientStreamShopping#" + std::to_string(i + 1)); + status = stream.Write(request); + if (status.OK()) { + ++send_count; + send_bytes += request.msg().size(); + std::cout << "send request msg:" << request.msg() << std::endl; + continue; + } + break; + } + + // Check: last write is ok. + if (status.OK()) { + status = stream.WriteDone(); + if (status.OK()) { + // Waits the final status of the RPC calling. + status = stream.Finish(); + } else { + std::cerr << "write done error: " << status.ToString() << std::endl; + ok = false; + } + } else { + std::cerr << "write error: " << status.ToString() << std::endl; + ok = false; + } + + if (status.OK()) { + std::cout << "stream rpc succeed, send count: " << send_count << ", send bytes: " << send_bytes + << ", reply: " << reply.msg() << std::endl; + } else { + std::cerr << "stream rpc failed:" << status.ToString() << std::endl; + ok = false; + } + } while (0); + return ok; +} + +bool CallServerStreamShopping(const StreamShoppingServiceProxyPtr& proxy) { + auto context = ::trpc::MakeClientContext(proxy); + ::trpc::test::shopping::ShoppingRequest request; + request.set_msg("ServerStreamShopping"); + auto stream = proxy->ServerStreamShopping(context, request); + ::trpc::Status status{0, 0, "OK"}; + bool ok{true}; + do { + if (!stream.GetStatus().OK()) { + std::cout << "stream error:" << stream.GetStatus().ToString() << std::endl; + ok = false; + break; + } + + for (;;) { + ::trpc::test::shopping::ShoppingReply reply; + status = stream.Read(&reply, 2000); + if (status.OK()) { + std::cout << "stream got reply:" << reply.msg() << std::endl; + continue; + } + if (status.StreamEof()) { + std::cout << "got EOF" << std::endl; + // Waits the final status of the RPC calling. + status = stream.Finish(); + } + break; + } + + if (status.OK()) { + std::cout << "stream rpc succeed" << std::endl; + } else { + std::cerr << "stream rpc failed:" << status.ToString() << std::endl; + ok = false; + } + } while (0); + return ok; +} + +bool CallBidiStreamShopping(const StreamShoppingServiceProxyPtr& proxy, int request_count) { + auto context = ::trpc::MakeClientContext(proxy); + auto stream = proxy->BidiStreamShopping(context); + ::trpc::Status status{0, 0, "OK"}; + bool ok{true}; + do { + if (!stream.GetStatus().OK()) { + std::cout << "stream error:" << stream.GetStatus().ToString() << std::endl; + ok = false; + break; + } + + for (int i = 0; i < request_count; ++i) { + std::stringstream request_msg; + request_msg << "BidiStreamShopping, " << (i + 1); + ::trpc::test::shopping::ShoppingRequest request; + request.set_msg(request_msg.str()); + status = stream.Write(request); + if (status.OK()) { + continue; + } + break; + } + + if (!status.OK()) { + std::cerr << "write error: " << status.ToString() << std::endl; + ok = false; + break; + } + + status = stream.WriteDone(); + if (!status.OK()) { + std::cerr << "write done error: " << status.ToString() << std::endl; + ok = false; + break; + } + + ::trpc::test::shopping::ShoppingReply reply; + for (;;) { + status = stream.Read(&reply, 2000); + if (status.OK()) { + std::stringstream reply_msg; + reply_msg << "stream got reply:" << reply.msg(); + std::cout << reply_msg.str() << std::endl; + continue; + } + if (status.StreamEof()) { + std::cout << "got EOF" << std::endl; + // Waits the final status of the RPC calling. + status = stream.Finish(); + } + break; + } + + if (status.OK()) { + std::cout << "stream rpc succeed" << std::endl; + } else { + std::cerr << "stream rpc failed:" << status.ToString() << std::endl; + ok = false; + } + } while (0); + return ok; +} + +namespace test::shopping { + +bool CallPurchase(const StreamShoppingServiceProxyPtr& proxy, int purchase_count) { + auto context = ::trpc::MakeClientContext(proxy); + ::trpc::test::shopping::ShoppingRequest request; + ::trpc::test::shopping::ShoppingReply reply; + request.set_purchase_count(purchase_count); + + ::trpc::Status status = proxy->Purchase(context, request, &reply); + if (status.OK()) { + std::cout << "抢购结果: " << (reply.success() ? "成功" : "失败") << std::endl; + std::cout << "消息: " << reply.msg() << std::endl; + std::cout << "剩余库存: " << reply.remaining_stock() << std::endl; + return reply.success(); + } else { + std::cerr << "调用失败: " << status.ToString() << std::endl; + return false; + } +} + +int Run() { + bool final_ok{true}; + + struct calling_args_t { + std::string calling_name; + std::function calling_executor; + bool ok; + }; + std::vector callings{}; + + int request_count = GetRequestCount(FLAGS_request_count); + std::string rpc_method = FLAGS_rpc_method; + + ::trpc::ServiceProxyOption option; + option.name = FLAGS_service_name; + option.codec_name = "trpc"; + option.network = "tcp"; + option.conn_type = "long"; + option.timeout = 1000; + option.selector_name = "direct"; + option.target = FLAGS_addr; + + auto stream_shopping_proxy = + ::trpc::GetTrpcClient()->GetProxy(FLAGS_service_name, option); + + std::string calling_name{""}; + std::function calling_executor{nullptr}; + if (rpc_method == "ClientStreamShopping") { + calling_name = "Streaming RPC, ClientStreamShopping"; + calling_executor = [&stream_shopping_proxy, request_count]() { + return CallClientStreamShopping(stream_shopping_proxy, request_count); + }; + } else if (rpc_method == "ServerStreamShopping") { + calling_name = "Streaming RPC, ServerStreamShopping"; + calling_executor = [&stream_shopping_proxy]() { return CallServerStreamShopping(stream_shopping_proxy); }; + } else if (rpc_method == "BidiStreamShopping") { + calling_name = "Streaming RPC, BidiStreamShopping"; + calling_executor = [&stream_shopping_proxy, request_count]() { + return CallBidiStreamShopping(stream_shopping_proxy, request_count); + }; + } else if (rpc_method == "Purchase") { + calling_name = "RPC, Purchase"; + calling_executor = [&stream_shopping_proxy, request_count]() { + return CallPurchase(stream_shopping_proxy, request_count); + }; + } else { + std::cout << "RPC method is empty, nothing todo" << std::endl; + return 0; + } + // Executing multiple cases is to send concurrent requests. + for (int i = 0; i < 8; i++) { + callings.push_back({calling_name + std::to_string(i + 1), calling_executor, false}); + } + + auto latch_count = static_cast(callings.size()); + ::trpc::FiberLatch callings_latch{latch_count}; + + for (auto& c : callings) { + ::trpc::StartFiberDetached([&callings_latch, &c]() { + c.ok = c.calling_executor(); + callings_latch.CountDown(); + }); + } + callings_latch.Wait(); + + for (const auto& c : callings) { + final_ok &= c.ok; + std::cout << "name: " << c.calling_name << ", ok: " << c.ok << std::endl; + } + + std::cout << "final result of streaming RPC calling: " << final_ok << std::endl; + return final_ok ? 0 : -1; +} + +} // namespace test::shopping + +bool ParseClientConfig(int argc, char* argv[]) { + google::ParseCommandLineFlags(&argc, &argv, true); + google::CommandLineFlagInfo info; + if (GetCommandLineFlagInfo("client_config", &info) && info.is_default) { + std::cerr << "start client with client_config, for example: " << argv[0] + << " --client_config=/client/client_config/filepath" << std::endl; + return false; + } + std::cout << "FLAGS_service_name: " << FLAGS_service_name << std::endl; + std::cout << "FLAGS_client_config: " << FLAGS_client_config << std::endl; + std::cout << "FLAGS_addr: " << FLAGS_addr << std::endl; + std::cout << "FLAGS_rpc_method: " << FLAGS_rpc_method << std::endl; + std::cout << "FLAGS_request_count: " << FLAGS_request_count << std::endl; + return true; +} + +int main(int argc, char* argv[]) { + if (!ParseClientConfig(argc, argv)) { + exit(-1); + } + + if (::trpc::TrpcConfig::GetInstance()->Init(FLAGS_client_config) != 0) { + std::cerr << "load client_config failed." << std::endl; + exit(-1); + } + + // If the business code is running in trpc pure client mode, the business code needs to be running in the + // `RunInTrpcRuntime` function + return ::trpc::RunInTrpcRuntime([]() { return test::helloworld::Run(); }); +} diff --git a/examples/features/trpc_stream/client/rawdata_stream_client_shopping.cc b/examples/features/trpc_stream/client/rawdata_stream_client_shopping.cc new file mode 100644 index 00000000..72f31e08 --- /dev/null +++ b/examples/features/trpc_stream/client/rawdata_stream_client_shopping.cc @@ -0,0 +1,208 @@ +// +// +// Tencent is pleased to support the open source community by making tRPC available. +// +// Copyright (C) 2023 THL A29 Limited, a Tencent company. +// All rights reserved. +// +// If you have downloaded a copy of the tRPC source code from Tencent, +// please note that tRPC source code is licensed under the Apache 2.0 License, +// A copy of the Apache 2.0 License is included in this file. +// +// + +#include +#include +#include +#include +#include + +#include "gflags/gflags.h" + +#include "trpc/client/make_client_context.h" +#include "trpc/client/rpc_service_proxy.h" +#include "trpc/client/trpc_client.h" +#include "trpc/common/config/trpc_config.h" +#include "trpc/common/runtime_manager.h" +#include "trpc/common/status.h" +#include "trpc/common/trpc_plugin.h" +#include "trpc/coroutine/fiber.h" +#include "trpc/coroutine/fiber_latch.h" +#include "trpc/util/log/logging.h" + +DEFINE_string(service_name, "trpc.test.shopping.RawDataStreamService", "callee service name"); +DEFINE_string(client_config, "trpc_cpp_fiber.yaml", ""); +DEFINE_string(addr, "127.0.0.1:24758", "ip:port"); +DEFINE_string(rpc_method, "Purchase", "RPC method name"); +DEFINE_int32(request_count, 3, "count of request"); + +namespace test::shopping { + +int GetRequestCount(int request_count) { + if (request_count > 0) { + return std::min(100, FLAGS_request_count); + } + return std::numeric_limits::max(); +} + +bool CallRawDataStreamReadWrite(const std::shared_ptr<::trpc::RpcServiceProxy>& proxy, int request_count) { + auto context = ::trpc::MakeClientContext(proxy); + // Setting the serialization type to kNoopType means that no serialization will be performed. + context->SetReqEncodeType(::trpc::serialization::kNoopType); + // Set the service interface name, which needs to be consistent with the method name registered on the server side. + context->SetFuncName("/trpc.test.shopping.RawDataStreamService/RawDataReadWrite"); + + auto stream = proxy->StreamInvoke<::trpc::NoncontiguousBuffer, ::trpc::NoncontiguousBuffer>(context); + ::trpc::Status status{0, 0, "OK"}; + bool ok{true}; + do { + if (!stream.GetStatus().OK()) { + std::cout << "stream error:" << stream.GetStatus().ToString() << std::endl; + ok = false; + break; + } + + for (int i = 0; i < request_count; ++i) { + std::stringstream request_msg; + request_msg << "RawDataBidiStream, " << (i + 1); + ::trpc::NoncontiguousBufferBuilder builder; + builder.Append(request_msg.str()); + status = stream.Write(builder.DestructiveGet()); + if (status.OK()) { + continue; + } + break; + } + + if (!status.OK()) { + std::cerr << "write error: " << status.ToString() << std::endl; + ok = false; + break; + } + + status = stream.WriteDone(); + if (!status.OK()) { + std::cerr << "write done error: " << status.ToString() << std::endl; + ok = false; + break; + } + + ::trpc::NoncontiguousBuffer reply; + for (;;) { + status = stream.Read(&reply, 2000); + if (status.OK()) { + std::stringstream reply_msg; + reply_msg << "stream got reply:" << ::trpc::FlattenSlow(reply); + std::cout << reply_msg.str() << std::endl; + continue; + } + if (status.StreamEof()) { + std::cout << "<< got EOF" << std::endl; + // Waits the final status of the RPC calling. + status = stream.Finish(); + } + break; + } + + if (status.OK()) { + std::cout << "<< stream rpc succeed" << std::endl; + } else { + std::cout << "<< stream rpc failed:" << status.ToString() << std::endl; + ok = false; + } + } while (0); + return ok; +} + +int Run() { + bool final_ok{true}; + + struct calling_args_t { + std::string calling_name; + std::function calling_executor; + bool ok; + }; + std::vector callings{}; + + int request_count = GetRequestCount(FLAGS_request_count); + std::string rpc_method = FLAGS_rpc_method; + + ::trpc::ServiceProxyOption option; + option.name = FLAGS_service_name; + option.codec_name = "trpc"; + option.network = "tcp"; + option.conn_type = "long"; + option.timeout = 1000; + option.selector_name = "direct"; + option.target = FLAGS_addr; + + auto raw_data_stream_proxy = ::trpc::GetTrpcClient()->GetProxy<::trpc::RpcServiceProxy>(FLAGS_service_name, option); + + std::string calling_name{""}; + std::function calling_executor{nullptr}; + if (rpc_method == "RawDataReadWrite") { + calling_name = "Streaming RPC, RawDataReadWrite"; + calling_executor = [&raw_data_stream_proxy, request_count]() { + return CallRawDataStreamReadWrite(raw_data_stream_proxy, request_count); + }; + } else { + std::cout << "RPC method is empty, nothing todo" << std::endl; + return 0; + } + // Executing multiple cases is to send concurrent requests. + for (int i = 0; i < 8; i++) { + callings.push_back({calling_name + std::to_string(i + 1), calling_executor, false}); + } + + auto latch_count = static_cast(callings.size()); + ::trpc::FiberLatch callings_latch{latch_count}; + + for (auto& c : callings) { + ::trpc::StartFiberDetached([&callings_latch, &c]() { + c.ok = c.calling_executor(); + callings_latch.CountDown(); + }); + } + callings_latch.Wait(); + + for (const auto& c : callings) { + final_ok &= c.ok; + std::cout << "name: " << c.calling_name << ", ok: " << c.ok << std::endl; + } + + std::cout << "final result of streaming RPC calling: " << final_ok << std::endl; + return final_ok ? 0 : -1; +} + +} // namespace test::shopping + +bool ParseClientConfig(int argc, char* argv[]) { + google::ParseCommandLineFlags(&argc, &argv, true); + google::CommandLineFlagInfo info; + if (GetCommandLineFlagInfo("client_config", &info) && info.is_default) { + std::cerr << "start client with client_config, for example: " << argv[0] + << " --client_config=/client/client_config/filepath" << std::endl; + return false; + } + std::cout << "FLAGS_service_name: " << FLAGS_service_name << std::endl; + std::cout << "FLAGS_client_config: " << FLAGS_client_config << std::endl; + std::cout << "FLAGS_addr: " << FLAGS_addr << std::endl; + std::cout << "FLAGS_rpc_method: " << FLAGS_rpc_method << std::endl; + std::cout << "FLAGS_request_count: " << FLAGS_request_count << std::endl; + return true; +} + +int main(int argc, char* argv[]) { + if (!ParseClientConfig(argc, argv)) { + exit(-1); + } + + if (::trpc::TrpcConfig::GetInstance()->Init(FLAGS_client_config) != 0) { + std::cerr << "load client_config failed." << std::endl; + exit(-1); + } + + // If the business code is running in trpc pure client mode, the business code needs to be running in the + // `RunInTrpcRuntime` function + return ::trpc::RunInTrpcRuntime([]() { return test::helloworld::Run(); }); +} diff --git a/examples/features/trpc_stream/server/stream_server_shopping.cc b/examples/features/trpc_stream/server/stream_server_shopping.cc new file mode 100644 index 00000000..401c0e40 --- /dev/null +++ b/examples/features/trpc_stream/server/stream_server_shopping.cc @@ -0,0 +1,53 @@ +// +// +// Tencent is pleased to support the open source community by making tRPC available. +// +// Copyright (C) 2023 THL A29 Limited, a Tencent company. +// All rights reserved. +// +// If you have downloaded a copy of the tRPC source code from Tencent, +// please note that tRPC source code is licensed under the Apache 2.0 License, +// A copy of the Apache 2.0 License is included in this file. +// +// + +#include +#include + +#include "trpc/common/logging/trpc_logging.h" +#include "trpc/common/trpc_app.h" + +#include "examples/features/trpc_stream/server/stream_service_shopping.h" + +namespace test { +namespace shopping { + +class ShoppingStreamServer : public ::trpc::TrpcApp { + public: + int Initialize() override { + const auto& config = ::trpc::TrpcConfig::GetInstance()->GetServerConfig(); + std::string service_name = fmt::format("{}.{}.{}.{}", "trpc", config.app, config.server, "StreamShopping"); + TRPC_FMT_INFO("service name:{}", service_name); + RegisterService(service_name, std::make_shared<::test::shopping::StreamShoppingServiceImpl>()); + + service_name = fmt::format("{}.{}.{}.{}", "trpc", config.app, config.server, "RawDataStreamService"); + TRPC_FMT_INFO("service name:{}", service_name); + RegisterService(service_name, std::make_shared<::test::shopping::RawDataStreamService>()); + + return 0; + } + + void Destroy() override {} +}; + +} // namespace shopping +} // namespace test + +int main(int argc, char** argv) { + test::shopping::ShoppingStreamServer server; + + server.Main(argc, argv); + server.Wait(); + + return 0; +} diff --git a/examples/features/trpc_stream/server/stream_service_shopping.cc b/examples/features/trpc_stream/server/stream_service_shopping.cc new file mode 100644 index 00000000..a441922b --- /dev/null +++ b/examples/features/trpc_stream/server/stream_service_shopping.cc @@ -0,0 +1,189 @@ +// +// +// Tencent is pleased to support the open source community by making tRPC available. +// +// Copyright (C) 2023 THL A29 Limited, a Tencent company. +// All rights reserved. +// +// If you have downloaded a copy of the tRPC source code from Tencent, +// please note that tRPC source code is licensed under the Apache 2.0 License, +// A copy of the Apache 2.0 License is included in this file. +// +// + +#include "examples/features/trpc_stream/server/stream_service_shopping.h" + +#include + +#include "gflags/gflags.h" + +// 全局库存变量 +int stock = 100; + +namespace test::shopping { + +// Client streaming RPC. +::trpc::Status StreamShoppingServiceImpl::ClientStreamShopping( + const ::trpc::ServerContextPtr& context, + const ::trpc::stream::StreamReader<::trpc::test::shopping::ShoppingRequest>& reader, + ::trpc::test::shopping::ShoppingReply* reply) { + ::trpc::Status status{}; + uint32_t request_counter{0}; + uint32_t request_bytes{0}; + for (;;) { + ::trpc::test::shopping::ShoppingRequest request{}; + status = reader.Read(&request, 3000); + if (status.OK()) { + ++request_counter; + request_bytes += request.msg().size(); + TRPC_FMT_INFO("server got request: {}", request.msg()); + continue; + } + if (status.StreamEof()) { + std::stringstream reply_msg; + reply_msg << "server got EOF, reply to client, server got request" + << ", count:" << request_counter << ", received bytes:" << request_bytes; + reply->set_msg(reply_msg.str()); + TRPC_FMT_INFO("reply to the client: {}", reply_msg.str()); + status = ::trpc::Status{0, 0, "OK"}; + break; + } + TRPC_FMT_ERROR("stream got error: {}", status.ToString()); + break; + } + return status; +} + +// Server streaming RPC. +::trpc::Status StreamShoppingServiceImpl::ServerStreamShopping( + const ::trpc::ServerContextPtr& context, + const ::trpc::test::shopping::ShoppingRequest& request, // NO LINT + ::trpc::stream::StreamWriter<::trpc::test::shopping::ShoppingReply>* writer) { + ::trpc::Status status{}; + // A simple case, try to reply 10 response messages to the client. + int request_count = 10; + for (int i = 0; i < request_count; ++i) { + std::stringstream reply_msg; + reply_msg << " reply: " << request.msg() << "#" << (i + 1); + ::trpc::test::shopping::ShoppingReply reply{}; + reply.set_msg(reply_msg.str()); + status = writer->Write(reply); + if (status.OK()) { + continue; + } + TRPC_FMT_ERROR("stream got error: {}", status.ToString()); + break; + } + return status; +} + +// Bi-direction streaming RPC. +::trpc::Status StreamShoppingServiceImpl::BidiStreamShopping( + const ::trpc::ServerContextPtr& context, + const ::trpc::stream::StreamReader<::trpc::test::shopping::ShoppingRequest>& reader, + ::trpc::stream::StreamWriter<::trpc::test::shopping::ShoppingReply>* writer) { + std::vector msg_list{}; + ::trpc::Status status{}; + uint32_t request_counter{0}; + uint32_t request_bytes{0}; + for (;;) { + ::trpc::test::shopping::ShoppingRequest request{}; + status = reader.Read(&request, 3000); + if (status.OK()) { + ++request_counter; + request_bytes += request.msg().size(); + std::stringstream reply_msg; + reply_msg << " reply:" << request_counter << ", received bytes:" << request_bytes; + ::trpc::test::shopping::ShoppingReply reply; + reply.set_msg(reply_msg.str()); + writer->Write(reply); + continue; + } + if (status.StreamEof()) { + std::stringstream reply_msg; + reply_msg << "server got EOF, reply to client, server got request" + << ", count:" << request_counter << ", received bytes:" << request_bytes; + ::trpc::test::shopping::ShoppingReply reply; + reply.set_msg(reply_msg.str()); + status = writer->Write(reply); + } + // ERROR. + break; + } + if (!status.OK()) { + TRPC_FMT_ERROR("stream go error: {}", status.ToString()); + } + return status; +} + +RawDataStreamService::RawDataStreamService() { + AddRpcServiceMethod(new ::trpc::RpcServiceMethod( + "/trpc.test.shopping.RawDataStreamService/RawDataReadWrite", ::trpc::MethodType::BIDI_STREAMING, + new ::trpc::StreamRpcMethodHandler<::trpc::NoncontiguousBuffer, ::trpc::NoncontiguousBuffer>( + std::bind(&RawDataStreamService::RawDataReadWrite, this, std::placeholders::_1, std::placeholders::_2, + std::placeholders::_3)))); +} + +::trpc::Status RawDataStreamService::RawDataReadWrite( + const ::trpc::ServerContextPtr& context, + const ::trpc::stream::StreamReader<::trpc::NoncontiguousBuffer>& reader, // NO LINT + ::trpc::stream::StreamWriter<::trpc::NoncontiguousBuffer>* writer) { + std::vector msg_list{}; + ::trpc::Status status{}; + uint32_t request_counter{0}; + uint32_t request_bytes{0}; + for (;;) { + ::trpc::NoncontiguousBuffer request{}; + status = reader.Read(&request, 3000); + if (status.OK()) { + ++request_counter; + request_bytes += request.ByteSize(); + TRPC_FMT_INFO("Recv msg: {}", ::trpc::FlattenSlow(request)); + std::stringstream reply_msg; + reply_msg << " reply:" << request_counter << ", received bytes:" << request_bytes; + ::trpc::NoncontiguousBufferBuilder builder; + builder.Append(reply_msg.str()); + writer->Write(builder.DestructiveGet()); + continue; + } + if (status.StreamEof()) { + TRPC_FMT_INFO("Recv eof"); + std::stringstream reply_msg; + reply_msg << "server got EOF, reply to client, server got request" + << ", count:" << request_counter << ", received bytes:" << request_bytes; + ::trpc::NoncontiguousBufferBuilder builder; + builder.Append(reply_msg.str()); + status = writer->Write(builder.DestructiveGet()); + } + // ERROR. + break; + } + if (!status.OK()) { + TRPC_FMT_ERROR("stream go error: {}", status.ToString()); + } + return status; +} + +// 实现 Purchase 方法 +::trpc::Status StreamShoppingServiceImpl::Purchase(const ::trpc::ServerContextPtr& context, + const ::trpc::test::shopping::ShoppingRequest& request, + ::trpc::test::shopping::ShoppingReply* reply) { + int purchase_count = request.purchase_count(); + if (purchase_count <= 0) { + reply->set_msg("购买件数必须大于 0"); + reply->set_success(false); + reply->set_remaining_stock(stock); + } else if (stock >= purchase_count) { + stock -= purchase_count; + reply->set_msg("抢购成功"); + reply->set_success(true); + reply->set_remaining_stock(stock); + } else { + reply->set_msg("库存不足,抢购失败"); + reply->set_success(false); + reply->set_remaining_stock(stock); + } + return ::trpc::Status::OK(); +} + +} // namespace test::shopping diff --git a/examples/features/trpc_stream/server/stream_service_shopping.h b/examples/features/trpc_stream/server/stream_service_shopping.h new file mode 100644 index 00000000..1a40ba7b --- /dev/null +++ b/examples/features/trpc_stream/server/stream_service_shopping.h @@ -0,0 +1,57 @@ +// +// +// Tencent is pleased to support the open source community by making tRPC available. +// +// Copyright (C) 2023 THL A29 Limited, a Tencent company. +// All rights reserved. +// +// If you have downloaded a copy of the tRPC source code from Tencent, +// please note that tRPC source code is licensed under the Apache 2.0 License, +// A copy of the Apache 2.0 License is included in this file. +// +// + +#pragma once + +#include "trpc/common/status.h" +#include "trpc/server/stream_rpc_method_handler.h" +#include "trpc/util/buffer/noncontiguous_buffer.h" + +#include "examples/features/trpc_stream/server/stream.trpc.pb.h" + +namespace test::shopping { + +class StreamShoppingServiceImpl : public ::trpc::test::shopping::StreamShopping { + public: + // Client streaming RPC. + ::trpc::Status ClientStreamShopping( + const ::trpc::ServerContextPtr& context, + const ::trpc::stream::StreamReader<::trpc::test::shopping::ShoppingRequest>& reader, + ::trpc::test::shopping::ShoppingReply* reply) override; + + // Server streaming RPC. + ::trpc::Status ServerStreamShopping( + const ::trpc::ServerContextPtr& context, const ::trpc::test::shopping::ShoppingRequest& request, + ::trpc::stream::StreamWriter<::trpc::test::shopping::ShoppingReply>* writer) override; + + // Bi-direction streaming RPC. + ::trpc::Status BidiStreamShopping( + const ::trpc::ServerContextPtr& context, + const ::trpc::stream::StreamReader<::trpc::test::shopping::ShoppingRequest>& reader, + ::trpc::stream::StreamWriter<::trpc::test::shopping::ShoppingReply>* writer) override; + + // 新增抢购方法声明 + ::trpc::Status Purchase(const ::trpc::ServerContextPtr& context, + const ::trpc::test::shopping::ShoppingRequest& request, + ::trpc::test::shopping::ShoppingReply* reply) override; +}; + +class RawDataStreamService : public ::trpc::RpcServiceImpl { + public: + RawDataStreamService(); + ::trpc::Status RawDataReadWrite(const ::trpc::ServerContextPtr& context, + const ::trpc::stream::StreamReader<::trpc::NoncontiguousBuffer>& reader, + ::trpc::stream::StreamWriter<::trpc::NoncontiguousBuffer>* writer); +}; + +} // namespace test::shopping diff --git a/examples/features/trpc_stream/server/stream_shopping.proto b/examples/features/trpc_stream/server/stream_shopping.proto new file mode 100644 index 00000000..f9f4fbda --- /dev/null +++ b/examples/features/trpc_stream/server/stream_shopping.proto @@ -0,0 +1,25 @@ +syntax = "proto3"; + +package trpc.test.shopping; + +service StreamShopping { + // Client streaming + rpc ClientStreamShopping (stream ShoppingRequest) returns (ShoppingReply) {} + + // Server streaming + rpc ServerStreamShopping (ShoppingRequest) returns (stream ShoppingReply) {} + + // Bidi streaming + rpc BidiStreamShopping (stream ShoppingRequest) returns (stream ShoppingReply) {} +} + +message ShoppingRequest { + string msg = 1; + int32 purchase_count = 2; +} + +message ShoppingReply { + string msg = 1; + bool success = 2; + int32 remaining_stock = 3; +} diff --git a/trpc/util/http/http_sse_event.h b/trpc/util/http/http_sse_event.h new file mode 100644 index 00000000..ef2e8e4a --- /dev/null +++ b/trpc/util/http/http_sse_event.h @@ -0,0 +1,43 @@ +#include +#include +#include +namespace trpc::http { + +/// @brief SSE event structure +struct SseEvent { + std::string event_type; // event field + std::string data; // data field + std::string id; // id field + std::optional retry; // retry field (milliseconds) + + /// @brief Serialize SSE event to string format + std::string ToString() const { + std::string result; + + if (!event_type.empty()) { + result += "event: " + event_type + "\n"; + } + + if (!data.empty()) { + // Handle multi-line data + std::istringstream iss(data); + std::string line; + while (std::getline(iss, line)) { + result += "data: " + line + "\n"; + } + } + + if (!id.empty()) { + result += "id: " + id + "\n"; + } + + if (retry.has_value()) { + result += "retry: " + std::to_string(retry.value()) + "\n"; + } + + result += "\n"; // End with double newline + return result; + } +}; + +} // namespace trpc::http \ No newline at end of file diff --git a/trpc/util/http/http_sse_event_test.cc b/trpc/util/http/http_sse_event_test.cc new file mode 100644 index 00000000..15ec04e6 --- /dev/null +++ b/trpc/util/http/http_sse_event_test.cc @@ -0,0 +1,234 @@ +// +// +// Tencent is pleased to support the open source community by making tRPC available. +// +// Copyright (C) 2023 THL A29 Limited, a Tencent company. +// All rights reserved. +// +// If you have downloaded a copy of the tRPC source code from Tencent, +// please note that tRPC source code is licensed under the Apache 2.0 License, +// A copy of the Apache 2.0 License is included in this file. +// +// + +#include "trpc/client/http/http_sse_event.h" + +#include + +namespace trpc::http { + +class SseEventTest : public ::testing::Test { + protected: + void SetUp() override {} + void TearDown() override {} +}; + +TEST_F(SseEventTest, ToStringBasicMessage) { + SseEvent event; + event.data = "This is the first message."; + + std::string result = event.ToString(); + std::string expected = "data: This is the first message.\\n\\n"; + + EXPECT_EQ(result, expected); +} + +TEST_F(SseEventTest, ToStringMultiLineMessage) { + SseEvent event; + event.data = "This is the second message, it\\nhas two lines."; + + std::string result = event.ToString(); + std::string expected = + "data: This is the second message, it\\n" + "data: has two lines.\\n\\n"; + + EXPECT_EQ(result, expected); +} + +TEST_F(SseEventTest, ToStringWithEventType) { + SseEvent event; + event.event_type = "add"; + event.data = "73857293"; + + std::string result = event.ToString(); + std::string expected = + "event: add\\n" + "data: 73857293\\n\\n"; + + EXPECT_EQ(result, expected); +} + +TEST_F(SseEventTest, ToStringRemoveEvent) { + SseEvent event; + event.event_type = "remove"; + event.data = "2153"; + + std::string result = event.ToString(); + std::string expected = + "event: remove\\n" + "data: 2153\\n\\n"; + + EXPECT_EQ(result, expected); +} + +TEST_F(SseEventTest, ToStringCompleteEvent) { + SseEvent event; + event.event_type = "notification"; + event.data = "Hello World"; + event.id = "123"; + event.retry = 5000; + + std::string result = event.ToString(); + std::string expected = + "event: notification\\n" + "data: Hello World\\n" + "id: 123\\n" + "retry: 5000\\n\\n"; + + EXPECT_EQ(result, expected); +} + +TEST_F(SseEventTest, ToStringEmptyEvent) { + SseEvent event; + + std::string result = event.ToString(); + std::string expected = "\\n"; + + EXPECT_EQ(result, expected); +} + +TEST_F(SseEventTest, ToStringOnlyId) { + SseEvent event; + event.id = "msg-001"; + + std::string result = event.ToString(); + std::string expected = "id: msg-001\\n\\n"; + + EXPECT_EQ(result, expected); +} + +TEST_F(SseEventTest, ToStringOnlyRetry) { + SseEvent event; + event.retry = 3000; + + std::string result = event.ToString(); + std::string expected = "retry: 3000\\n\\n"; + + EXPECT_EQ(result, expected); +} + +TEST_F(SseEventTest, ToStringMultipleDataLines) { + SseEvent event; + event.event_type = "multiline"; + event.data = "Line 1\\nLine 2\\nLine 3"; + event.id = "multi-001"; + + std::string result = event.ToString(); + std::string expected = + "event: multiline\\n" + "data: Line 1\\n" + "data: Line 2\\n" + "data: Line 3\\n" + "id: multi-001\\n\\n"; + + EXPECT_EQ(result, expected); +} + +TEST_F(SseEventTest, ToStringEmptyData) { + SseEvent event; + event.event_type = "ping"; + event.data = ""; + + std::string result = event.ToString(); + std::string expected = "event: ping\\n\\n"; + + EXPECT_EQ(result, expected); +} + +TEST_F(SseEventTest, ToStringSpecialCharacters) { + SseEvent event; + event.event_type = "special"; + event.data = "Data with: colons and\\nnewlines"; + event.id = "special-123"; + + std::string result = event.ToString(); + std::string expected = + "event: special\\n" + "data: Data with: colons and\\n" + "data: newlines\\n" + "id: special-123\\n\\n"; + + EXPECT_EQ(result, expected); +} + +// Test sequence of events as described in the requirements +TEST_F(SseEventTest, TestSequenceOfBasicMessages) { + std::vector events; + + // First message + SseEvent event1; + event1.data = "This is the first message."; + events.push_back(event1); + + // Second message (multi-line) + SseEvent event2; + event2.data = "This is the second message, it\\nhas two lines."; + events.push_back(event2); + + // Third message + SseEvent event3; + event3.data = "This is the third message."; + events.push_back(event3); + + std::string combined_output; + for (const auto& event : events) { + combined_output += event.ToString(); + } + + std::string expected = + "data: This is the first message.\\n\\n" + "data: This is the second message, it\\n" + "data: has two lines.\\n\\n" + "data: This is the third message.\\n\\n"; + + EXPECT_EQ(combined_output, expected); +} + +TEST_F(SseEventTest, TestSequenceOfTypedEvents) { + std::vector events; + + // Add event + SseEvent event1; + event1.event_type = "add"; + event1.data = "73857293"; + events.push_back(event1); + + // Remove event + SseEvent event2; + event2.event_type = "remove"; + event2.data = "2153"; + events.push_back(event2); + + // Another add event + SseEvent event3; + event3.event_type = "add"; + event3.data = "113411"; + events.push_back(event3); + + std::string combined_output; + for (const auto& event : events) { + combined_output += event.ToString(); + } + + std::string expected = + "event: add\\n" + "data: 73857293\\n\\n" + "event: remove\\n" + "data: 2153\\n\\n" + "event: add\\n" + "data: 113411\\n\\n"; + + EXPECT_EQ(combined_output, expected); +} + +} // namespace trpc::http \ No newline at end of file diff --git a/trpc/util/http/http_sse_parser.h b/trpc/util/http/http_sse_parser.h new file mode 100644 index 00000000..1f7992c9 --- /dev/null +++ b/trpc/util/http/http_sse_parser.h @@ -0,0 +1,130 @@ +#pragma once + +#include +#include +#include +#include + +#include "trpc/client/http/http_sse_event.h" + +namespace trpc::http { + +/// @brief SSE parser for parsing text messages to SseEvent objects +class SseParser { + public: + /// @brief Parse SSE text stream into SseEvent objects + /// @param text SSE formatted text stream + /// @return Vector of parsed SseEvent objects + static std::vector Parse(const std::string& text) { + std::vector events; + std::istringstream stream(text); + std::string line; + + SseEvent current_event; + std::vector data_lines; + + while (std::getline(stream, line)) { + // Remove \\r if present (handle \\r\\n line endings) + if (!line.empty() && line.back() == '\\r') { + line.pop_back(); + } + + // Empty line indicates end of event + if (line.empty()) { + if (!data_lines.empty() || !current_event.event_type.empty() || !current_event.id.empty() || + current_event.retry.has_value()) { + // Join data lines with newlines + if (!data_lines.empty()) { + current_event.data = JoinDataLines(data_lines); + } + + // Set default event type if not specified + if (current_event.event_type.empty()) { + current_event.event_type = "message"; + } + + events.push_back(current_event); + + // Reset for next event + current_event = SseEvent{}; + data_lines.clear(); + } + continue; + } + + // Parse field lines + auto colon_pos = line.find(':'); + if (colon_pos == std::string::npos) { + continue; // Skip malformed lines + } + + std::string field = line.substr(0, colon_pos); + std::string value = line.substr(colon_pos + 1); + + // Remove leading space from value if present + if (!value.empty() && value[0] == ' ') { + value = value.substr(1); + } + + if (field == "data") { + data_lines.push_back(value); + } else if (field == "event") { + current_event.event_type = value; + } else if (field == "id") { + current_event.id = value; + } else if (field == "retry") { + try { + current_event.retry = std::stoi(value); + } catch (const std::exception&) { + // Ignore invalid retry values + } + } + // Ignore unknown fields as per SSE specification + } + + // Handle case where stream doesn't end with empty line + if (!data_lines.empty() || !current_event.event_type.empty() || !current_event.id.empty() || + current_event.retry.has_value()) { + if (!data_lines.empty()) { + current_event.data = JoinDataLines(data_lines); + } + if (current_event.event_type.empty()) { + current_event.event_type = "message"; + } + events.push_back(current_event); + } + + return events; + } + + /// @brief Serialize SseEvent object to text format + /// @param event SseEvent object to serialize + /// @return SSE formatted text string + static std::string Serialize(const SseEvent& event) { return event.ToString(); } + + /// @brief Serialize multiple SseEvent objects to text format + /// @param events Vector of SseEvent objects to serialize + /// @return SSE formatted text stream + static std::string SerializeMultiple(const std::vector& events) { + std::string result; + for (const auto& event : events) { + result += event.ToString(); + } + return result; + } + + private: + /// @brief Join data lines with newlines + static std::string JoinDataLines(const std::vector& lines) { + if (lines.empty()) return ""; + if (lines.size() == 1) return lines[0]; + + std::string result = lines[0]; + for (size_t i = 1; i < lines.size(); ++i) { + result += "\\n" + lines[i]; + } + return result; + } +}; + +} // namespace trpc::http \ No newline at end of file