// Licensed to the Apache Software Foundation (ASF) under one // or more contributor license agreements. See the NOTICE file // distributed with this work for additional information // regarding copyright ownership. The ASF licenses this file // to you under the Apache License, Version 2.0 (the // "License"); you may not use this file except in compliance // with the License. You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, // software distributed under the License is distributed on an // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY // KIND, either express or implied. See the License for the // specific language governing permissions and limitations // under the License. #include "kudu/security/tls_handshake.h" #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include "kudu/gutil/macros.h" #include "kudu/security/tls_context.h" #include "kudu/util/countdown_latch.h" #include "kudu/util/monotime.h" #include "kudu/util/net/sockaddr.h" #include "kudu/util/net/socket.h" #include "kudu/util/random.h" #include "kudu/util/random_util.h" #include "kudu/util/scoped_cleanup.h" #include "kudu/util/status.h" #include "kudu/util/test_macros.h" #include "kudu/util/test_util.h" using std::string; using std::thread; using std::unique_ptr; using std::vector; namespace kudu { namespace security { const MonoDelta kTimeout = MonoDelta::FromSeconds(10); // Size is big enough to not fit into output socket buffer of default size // (controlled by setsockopt() with SO_SNDBUF). constexpr size_t kEchoChunkSize = 32 * 1024 * 1024; class TlsSocketTest : public KuduTest { public: void SetUp() override { KuduTest::SetUp(); ASSERT_OK(client_tls_.Init()); } protected: void ConnectClient(const Sockaddr& addr, unique_ptr* sock); TlsContext client_tls_; }; Status DoNegotiationSide(Socket* sock, TlsHandshake* tls, const char* side) { tls->set_verification_mode(TlsVerificationMode::VERIFY_NONE); bool done = false; string received; while (!done) { string to_send; Status s = tls->Continue(received, &to_send); if (s.ok()) { done = true; } else if (!s.IsIncomplete()) { RETURN_NOT_OK_PREPEND(s, "unexpected tls error"); } if (!to_send.empty()) { size_t nwritten; auto deadline = MonoTime::Now() + MonoDelta::FromSeconds(10); RETURN_NOT_OK_PREPEND(sock->BlockingWrite( reinterpret_cast(to_send.data()), to_send.size(), &nwritten, deadline), "error sending"); } if (!done) { uint8_t buf[1024]; int32_t n = 0; RETURN_NOT_OK_PREPEND(sock->Recv(buf, arraysize(buf), &n), "error receiving"); received = string(reinterpret_cast(&buf[0]), n); } } LOG(INFO) << side << ": negotiation complete"; return Status::OK(); } void TlsSocketTest::ConnectClient(const Sockaddr& addr, unique_ptr* sock) { unique_ptr client_sock(new Socket()); ASSERT_OK(client_sock->Init(0)); ASSERT_OK(client_sock->Connect(addr)); TlsHandshake client; ASSERT_OK(client_tls_.InitiateHandshake(TlsHandshakeType::CLIENT, &client)); ASSERT_OK(DoNegotiationSide(client_sock.get(), &client, "client")); ASSERT_OK(client.Finish(&client_sock)); *sock = std::move(client_sock); } class EchoServer { public: EchoServer() : pthread_sync_(1) { } ~EchoServer() { Stop(); Join(); } void Start() { ASSERT_OK(server_tls_.Init()); ASSERT_OK(server_tls_.GenerateSelfSignedCertAndKey()); ASSERT_OK(listen_addr_.ParseString("127.0.0.1", 0)); ASSERT_OK(listener_.Init(0)); ASSERT_OK(listener_.BindAndListen(listen_addr_, /*listen_queue_size=*/10)); ASSERT_OK(listener_.GetSocketAddress(&listen_addr_)); thread_ = thread([&] { pthread_ = pthread_self(); pthread_sync_.CountDown(); unique_ptr sock(new Socket()); Sockaddr remote; CHECK_OK(listener_.Accept(sock.get(), &remote, /*flags=*/0)); TlsHandshake server; CHECK_OK(server_tls_.InitiateHandshake(TlsHandshakeType::SERVER, &server)); CHECK_OK(DoNegotiationSide(sock.get(), &server, "server")); CHECK_OK(server.Finish(&sock)); CHECK_OK(sock->SetRecvTimeout(kTimeout)); unique_ptr buf(new uint8_t[kEchoChunkSize]); // An "echo" loop for kEchoChunkSize byte buffers. while (!stop_) { size_t n; Status s = sock->BlockingRecv(buf.get(), kEchoChunkSize, &n, MonoTime::Now() + kTimeout); if (!s.ok()) { CHECK(stop_) << "unexpected error reading: " << s.ToString(); } LOG(INFO) << "server echoing " << n << " bytes"; size_t written; s = sock->BlockingWrite(buf.get(), n, &written, MonoTime::Now() + kTimeout); if (!s.ok()) { CHECK(stop_) << "unexpected error writing: " << s.ToString(); } if (slow_read_) { SleepFor(MonoDelta::FromMilliseconds(10)); } } }); } void EnableSlowRead() { slow_read_ = true; } const Sockaddr& listen_addr() const { return listen_addr_; } bool stopped() const { return stop_; } void Stop() { stop_ = true; } void Join() { thread_.join(); } const pthread_t& pthread() { pthread_sync_.Wait(); return pthread_; } private: TlsContext server_tls_; Socket listener_; Sockaddr listen_addr_; thread thread_; pthread_t pthread_; CountDownLatch pthread_sync_; std::atomic stop_ { false }; bool slow_read_ = false; }; void handler(int /* signal */) {} // Test for failures to handle EINTR during TLS connection // negotiation and data send/receive. TEST_F(TlsSocketTest, TestTlsSocketInterrupted) { // Set up a no-op signal handler for SIGUSR2. struct sigaction sa, sa_old; memset(&sa, 0, sizeof(sa)); sa.sa_handler = &handler; sigaction(SIGUSR2, &sa, &sa_old); SCOPED_CLEANUP({ sigaction(SIGUSR2, &sa_old, nullptr); }); EchoServer server; NO_FATALS(server.Start()); // Start a thread to send signals to the server thread. thread killer([&]() { while (!server.stopped()) { PCHECK(pthread_kill(server.pthread(), SIGUSR2) == 0); SleepFor(MonoDelta::FromMicroseconds(rand() % 10)); } }); SCOPED_CLEANUP({ killer.join(); }); unique_ptr client_sock; NO_FATALS(ConnectClient(server.listen_addr(), &client_sock)); unique_ptr buf(new uint8_t[kEchoChunkSize]); for (int i = 0; i < 10; i++) { SleepFor(MonoDelta::FromMilliseconds(1)); size_t nwritten; ASSERT_OK(client_sock->BlockingWrite(buf.get(), kEchoChunkSize, &nwritten, MonoTime::Now() + kTimeout)); size_t n; ASSERT_OK(client_sock->BlockingRecv(buf.get(), kEchoChunkSize, &n, MonoTime::Now() + kTimeout)); } server.Stop(); ASSERT_OK(client_sock->Close()); LOG(INFO) << "client done"; } // Return an iovec containing the same data as the buffer 'buf' with the length 'len', // but split into random-sized chunks. The chunks are sized randomly between 1 and // 'max_chunk_size' bytes. vector ChunkIOVec(Random* rng, uint8_t* buf, int len, int max_chunk_size) { vector ret; uint8_t* p = buf; int rem = len; while (rem > 0) { int len = rng->Uniform(max_chunk_size) + 1; len = std::min(len, rem); ret.push_back({p, static_cast(len)}); p += len; rem -= len; } return ret; } // Regression test for KUDU-2218, a bug in which Writev would improperly handle // partial writes in non-blocking mode. TEST_F(TlsSocketTest, TestNonBlockingWritev) { Random rng(GetRandomSeed32()); EchoServer server; server.EnableSlowRead(); NO_FATALS(server.Start()); unique_ptr client_sock; NO_FATALS(ConnectClient(server.listen_addr(), &client_sock)); int sndbuf = 16 * 1024; CHECK_ERR(setsockopt(client_sock->GetFd(), SOL_SOCKET, SO_SNDBUF, &sndbuf, sizeof(sndbuf))); unique_ptr buf(new uint8_t[kEchoChunkSize]); unique_ptr rbuf(new uint8_t[kEchoChunkSize]); RandomString(buf.get(), kEchoChunkSize, &rng); for (int i = 0; i < 10; i++) { ASSERT_OK(client_sock->SetNonBlocking(true)); // Prepare an IOV with the input data split into a bunch of randomly-sized // chunks. vector iov = ChunkIOVec(&rng, buf.get(), kEchoChunkSize, 1024 * 1024); // Loop calling writev until the iov is exhausted int rem = kEchoChunkSize; while (rem > 0) { CHECK(!iov.empty()) << rem; int64_t n; Status s = client_sock->Writev(&iov[0], iov.size(), &n); if (Socket::IsTemporarySocketError(s.posix_code())) { sched_yield(); continue; } ASSERT_OK(s); ASSERT_LE(n, rem); rem -= n; ASSERT_GE(n, 0); while (n > 0) { if (n < iov[0].iov_len) { iov[0].iov_len -= n; iov[0].iov_base = reinterpret_cast(iov[0].iov_base) + n; n = 0; } else { n -= iov[0].iov_len; iov.erase(iov.begin()); } } } LOG(INFO) << "client waiting"; size_t n; ASSERT_OK(client_sock->SetNonBlocking(false)); ASSERT_OK(client_sock->BlockingRecv(rbuf.get(), kEchoChunkSize, &n, MonoTime::Now() + kTimeout)); LOG(INFO) << "client got response"; ASSERT_EQ(0, memcmp(buf.get(), rbuf.get(), kEchoChunkSize)); } server.Stop(); ASSERT_OK(client_sock->Close()); } } // namespace security } // namespace kudu