Skip to content

Commit

Permalink
Fix accessing destroyed objects in the callback of async_wait (#362)
Browse files Browse the repository at this point in the history
Fixes #358
Fixes #359

### Motivation

`async_wait` is not used correctly in some places. A callback that
captures the `this` pointer or reference to `this` is passed to
`async_wait`, if this object is destroyed when the callback is called,
an invalid memory access will happen.

### Modifications

Use the following pattern in all `async_wait` calls.

```c++
std::weak_ptr<T> weakSelf{shared_from_this()};
timer_->async_wait([weakSelf](/* ... */) {
    if (auto self = weakSelf.lock()) {
        self->foo();
    }
});
```
  • Loading branch information
BewareMyPower authored Dec 6, 2023
1 parent 63d494f commit 24ab12c
Show file tree
Hide file tree
Showing 12 changed files with 61 additions and 30 deletions.
12 changes: 7 additions & 5 deletions lib/ConsumerImpl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ ConsumerImpl::ConsumerImpl(const ClientImplPtr client, const std::string& topic,
consumerId_(client->newConsumerId()),
consumerStr_("[" + topic + ", " + subscriptionName + ", " + std::to_string(consumerId_) + "] "),
messageListenerRunning_(true),
negativeAcksTracker_(client, *this, conf),
negativeAcksTracker_(std::make_shared<NegativeAcksTracker>(client, *this, conf)),
readCompacted_(conf.isReadCompacted()),
startMessageId_(startMessageId),
maxPendingChunkedMessage_(conf.getMaxPendingChunkedMessage()),
Expand All @@ -104,6 +104,7 @@ ConsumerImpl::ConsumerImpl(const ClientImplPtr client, const std::string& topic,
} else {
unAckedMessageTrackerPtr_.reset(new UnAckedMessageTrackerDisabled());
}
unAckedMessageTrackerPtr_->start();

// Setup stats reporter.
unsigned int statsIntervalInSeconds = client->getClientConfig().getStatsIntervalInSeconds();
Expand Down Expand Up @@ -1227,7 +1228,7 @@ std::pair<MessageId, bool> ConsumerImpl::prepareCumulativeAck(const MessageId& m

void ConsumerImpl::negativeAcknowledge(const MessageId& messageId) {
unAckedMessageTrackerPtr_->remove(messageId);
negativeAcksTracker_.add(messageId);
negativeAcksTracker_->add(messageId);
}

void ConsumerImpl::disconnectConsumer() {
Expand Down Expand Up @@ -1265,7 +1266,7 @@ void ConsumerImpl::closeAsync(ResultCallback originalCallback) {
if (ackGroupingTrackerPtr_) {
ackGroupingTrackerPtr_->close();
}
negativeAcksTracker_.close();
negativeAcksTracker_->close();

ClientConnectionPtr cnx = getCnx().lock();
if (!cnx) {
Expand Down Expand Up @@ -1303,7 +1304,7 @@ void ConsumerImpl::shutdown() {
if (client) {
client->cleanupConsumer(this);
}
negativeAcksTracker_.close();
negativeAcksTracker_->close();
cancelTimers();
consumerCreatedPromise_.setFailed(ResultAlreadyClosed);
failPendingReceiveCallback();
Expand Down Expand Up @@ -1608,7 +1609,7 @@ void ConsumerImpl::internalGetLastMessageIdAsync(const BackoffPtr& backoff, Time
}

void ConsumerImpl::setNegativeAcknowledgeEnabledForTesting(bool enabled) {
negativeAcksTracker_.setEnabledForTesting(enabled);
negativeAcksTracker_->setEnabledForTesting(enabled);
}

void ConsumerImpl::trackMessage(const MessageId& messageId) {
Expand Down Expand Up @@ -1695,6 +1696,7 @@ void ConsumerImpl::cancelTimers() noexcept {
boost::system::error_code ec;
batchReceiveTimer_->cancel(ec);
checkExpiredChunkedTimer_->cancel(ec);
unAckedMessageTrackerPtr_->stop();
}

void ConsumerImpl::processPossibleToDLQ(const MessageId& messageId, ProcessDLQCallBack cb) {
Expand Down
2 changes: 1 addition & 1 deletion lib/ConsumerImpl.h
Original file line number Diff line number Diff line change
Expand Up @@ -223,7 +223,7 @@ class ConsumerImpl : public ConsumerImplBase {
CompressionCodecProvider compressionCodecProvider_;
UnAckedMessageTrackerPtr unAckedMessageTrackerPtr_;
BrokerConsumerStatsImpl brokerConsumerStats_;
NegativeAcksTracker negativeAcksTracker_;
std::shared_ptr<NegativeAcksTracker> negativeAcksTracker_;
AckGroupingTrackerPtr ackGroupingTrackerPtr_;

MessageCryptoPtr msgCrypto_;
Expand Down
1 change: 1 addition & 0 deletions lib/MultiTopicsConsumerImpl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@ MultiTopicsConsumerImpl::MultiTopicsConsumerImpl(ClientImplPtr client, const std
} else {
unAckedMessageTrackerPtr_.reset(new UnAckedMessageTrackerDisabled());
}
unAckedMessageTrackerPtr_->start();
auto partitionsUpdateInterval = static_cast<unsigned int>(client->conf().getPartitionsUpdateInterval());
if (partitionsUpdateInterval > 0) {
partitionsUpdateTimer_ = listenerExecutor_->createDeadlineTimer();
Expand Down
7 changes: 6 additions & 1 deletion lib/NegativeAcksTracker.cc
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,13 @@ void NegativeAcksTracker::scheduleTimer() {
if (closed_) {
return;
}
std::weak_ptr<NegativeAcksTracker> weakSelf{shared_from_this()};
timer_->expires_from_now(timerInterval_);
timer_->async_wait(std::bind(&NegativeAcksTracker::handleTimer, this, std::placeholders::_1));
timer_->async_wait([weakSelf](const boost::system::error_code &ec) {
if (auto self = weakSelf.lock()) {
self->handleTimer(ec);
}
});
}

void NegativeAcksTracker::handleTimer(const boost::system::error_code &ec) {
Expand Down
2 changes: 1 addition & 1 deletion lib/NegativeAcksTracker.h
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ using DeadlineTimerPtr = std::shared_ptr<boost::asio::deadline_timer>;
class ExecutorService;
using ExecutorServicePtr = std::shared_ptr<ExecutorService>;

class NegativeAcksTracker {
class NegativeAcksTracker : public std::enable_shared_from_this<NegativeAcksTracker> {
public:
NegativeAcksTracker(ClientImplPtr client, ConsumerImpl &consumer, const ConsumerConfiguration &conf);

Expand Down
17 changes: 13 additions & 4 deletions lib/PatternMultiTopicsConsumerImpl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,13 @@ const PULSAR_REGEX_NAMESPACE::regex PatternMultiTopicsConsumerImpl::getPattern()
void PatternMultiTopicsConsumerImpl::resetAutoDiscoveryTimer() {
autoDiscoveryRunning_ = false;
autoDiscoveryTimer_->expires_from_now(seconds(conf_.getPatternAutoDiscoveryPeriod()));
autoDiscoveryTimer_->async_wait(
std::bind(&PatternMultiTopicsConsumerImpl::autoDiscoveryTimerTask, this, std::placeholders::_1));

auto weakSelf = weak_from_this();
autoDiscoveryTimer_->async_wait([weakSelf](const boost::system::error_code& err) {
if (auto self = weakSelf.lock()) {
self->autoDiscoveryTimerTask(err);
}
});
}

void PatternMultiTopicsConsumerImpl::autoDiscoveryTimerTask(const boost::system::error_code& err) {
Expand Down Expand Up @@ -222,8 +227,12 @@ void PatternMultiTopicsConsumerImpl::start() {

if (conf_.getPatternAutoDiscoveryPeriod() > 0) {
autoDiscoveryTimer_->expires_from_now(seconds(conf_.getPatternAutoDiscoveryPeriod()));
autoDiscoveryTimer_->async_wait(
std::bind(&PatternMultiTopicsConsumerImpl::autoDiscoveryTimerTask, this, std::placeholders::_1));
auto weakSelf = weak_from_this();
autoDiscoveryTimer_->async_wait([weakSelf](const boost::system::error_code& err) {
if (auto self = weakSelf.lock()) {
self->autoDiscoveryTimerTask(err);
}
});
}
}

Expand Down
4 changes: 4 additions & 0 deletions lib/PatternMultiTopicsConsumerImpl.h
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,10 @@ class PatternMultiTopicsConsumerImpl : public MultiTopicsConsumerImpl {
void onTopicsRemoved(NamespaceTopicsPtr removedTopics, ResultCallback callback);
void handleOneTopicAdded(const Result result, const std::string& topic,
std::shared_ptr<std::atomic<int>> topicsNeedCreate, ResultCallback callback);

std::weak_ptr<PatternMultiTopicsConsumerImpl> weak_from_this() noexcept {
return std::static_pointer_cast<PatternMultiTopicsConsumerImpl>(shared_from_this());
}
};

} // namespace pulsar
Expand Down
19 changes: 10 additions & 9 deletions lib/UnAckedMessageTrackerEnabled.cc
Original file line number Diff line number Diff line change
Expand Up @@ -35,11 +35,11 @@ void UnAckedMessageTrackerEnabled::timeoutHandler() {
ExecutorServicePtr executorService = client_->getIOExecutorProvider()->get();
timer_ = executorService->createDeadlineTimer();
timer_->expires_from_now(boost::posix_time::milliseconds(tickDurationInMs_));
timer_->async_wait([&](const boost::system::error_code& ec) {
if (ec) {
LOG_DEBUG("Ignoring timer cancelled event, code[" << ec << "]");
} else {
timeoutHandler();
std::weak_ptr<UnAckedMessageTrackerEnabled> weakSelf{shared_from_this()};
timer_->async_wait([weakSelf](const boost::system::error_code& ec) {
auto self = weakSelf.lock();
if (self && !ec) {
self->timeoutHandler();
}
});
}
Expand Down Expand Up @@ -91,10 +91,10 @@ UnAckedMessageTrackerEnabled::UnAckedMessageTrackerEnabled(long timeoutMs, long
std::set<MessageId> msgIds;
timePartitions.push_back(msgIds);
}

timeoutHandler();
}

void UnAckedMessageTrackerEnabled::start() { timeoutHandler(); }

bool UnAckedMessageTrackerEnabled::add(const MessageId& msgId) {
std::lock_guard<std::recursive_mutex> acquire(lock_);
auto id = discardBatch(msgId);
Expand Down Expand Up @@ -172,9 +172,10 @@ void UnAckedMessageTrackerEnabled::clear() {
}
}

UnAckedMessageTrackerEnabled::~UnAckedMessageTrackerEnabled() {
void UnAckedMessageTrackerEnabled::stop() {
boost::system::error_code ec;
if (timer_) {
timer_->cancel();
timer_->cancel(ec);
}
}
} /* namespace pulsar */
19 changes: 11 additions & 8 deletions lib/UnAckedMessageTrackerEnabled.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
#include <boost/asio/deadline_timer.hpp>
#include <deque>
#include <map>
#include <memory>
#include <mutex>
#include <set>

Expand All @@ -34,19 +35,21 @@ class ConsumerImplBase;
using ClientImplPtr = std::shared_ptr<ClientImpl>;
using DeadlineTimerPtr = std::shared_ptr<boost::asio::deadline_timer>;

class UnAckedMessageTrackerEnabled : public UnAckedMessageTrackerInterface {
class UnAckedMessageTrackerEnabled : public std::enable_shared_from_this<UnAckedMessageTrackerEnabled>,
public UnAckedMessageTrackerInterface {
public:
~UnAckedMessageTrackerEnabled();
UnAckedMessageTrackerEnabled(long timeoutMs, ClientImplPtr, ConsumerImplBase&);
UnAckedMessageTrackerEnabled(long timeoutMs, long tickDuration, ClientImplPtr, ConsumerImplBase&);
bool add(const MessageId& msgId);
bool remove(const MessageId& msgId);
void remove(const MessageIdList& msgIds);
void removeMessagesTill(const MessageId& msgId);
void removeTopicMessage(const std::string& topic);
void start() override;
void stop() override;
bool add(const MessageId& msgId) override;
bool remove(const MessageId& msgId) override;
void remove(const MessageIdList& msgIds) override;
void removeMessagesTill(const MessageId& msgId) override;
void removeTopicMessage(const std::string& topic) override;
void timeoutHandler();

void clear();
void clear() override;

protected:
void timeoutHandlerHelper();
Expand Down
2 changes: 2 additions & 0 deletions lib/UnAckedMessageTrackerInterface.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@ class UnAckedMessageTrackerInterface {
public:
virtual ~UnAckedMessageTrackerInterface() {}
UnAckedMessageTrackerInterface() {}
virtual void start() {}
virtual void stop() {}
virtual bool add(const MessageId& m) = 0;
virtual bool remove(const MessageId& m) = 0;
virtual void remove(const MessageIdList& msgIds) = 0;
Expand Down
2 changes: 2 additions & 0 deletions tests/BasicEndToEndTest.cc
Original file line number Diff line number Diff line change
Expand Up @@ -3973,6 +3973,7 @@ TEST(BasicEndToEndTest, testUnAckedMessageTrackerEnabledIndividualAck) {

auto tracker0 = std::make_shared<UnAckedMessageTrackerEnabledMock>(unAckedMessagesTimeoutMs,
clientImplPtr, consumerImpl0);
tracker0->start();
ASSERT_EQ(tracker0->getUnAckedMessagesTimeoutMs(), unAckedMessagesTimeoutMs);
ASSERT_EQ(tracker0->getTickDurationInMs(), unAckedMessagesTimeoutMs);

Expand Down Expand Up @@ -4048,6 +4049,7 @@ TEST(BasicEndToEndTest, testUnAckedMessageTrackerEnabledCumulativeAck) {
}
auto tracker = std::make_shared<UnAckedMessageTrackerEnabledMock>(unAckedMessagesTimeoutMs, clientImplPtr,
consumerImpl0);
tracker->start();
for (auto idx = 0; idx < numMsg; ++idx) {
ASSERT_TRUE(tracker->add(recvMsgId[idx]));
}
Expand Down
4 changes: 3 additions & 1 deletion tests/ConsumerTest.cc
Original file line number Diff line number Diff line change
Expand Up @@ -993,6 +993,7 @@ TEST(ConsumerTest, testRedeliveryOfDecryptionFailedMessages) {
auto consumer2ImplPtr = PulsarFriend::getConsumerImplPtr(consumer2);
consumer2ImplPtr->unAckedMessageTrackerPtr_.reset(new UnAckedMessageTrackerEnabled(
100, 100, PulsarFriend::getClientImplPtr(client), static_cast<ConsumerImplBase&>(*consumer2ImplPtr)));
consumer2ImplPtr->unAckedMessageTrackerPtr_->start();

ConsumerConfiguration consConfig3;
consConfig3.setConsumerType(pulsar::ConsumerShared);
Expand All @@ -1003,6 +1004,7 @@ TEST(ConsumerTest, testRedeliveryOfDecryptionFailedMessages) {
auto consumer3ImplPtr = PulsarFriend::getConsumerImplPtr(consumer3);
consumer3ImplPtr->unAckedMessageTrackerPtr_.reset(new UnAckedMessageTrackerEnabled(
100, 100, PulsarFriend::getClientImplPtr(client), static_cast<ConsumerImplBase&>(*consumer3ImplPtr)));
consumer3ImplPtr->unAckedMessageTrackerPtr_->start();

int numberOfMessages = 20;
std::string msgContent = "msg-content";
Expand Down Expand Up @@ -1222,7 +1224,7 @@ TEST(ConsumerTest, testNegativeAcksTrackerClose) {

consumer.close();
auto consumerImplPtr = PulsarFriend::getConsumerImplPtr(consumer);
ASSERT_TRUE(consumerImplPtr->negativeAcksTracker_.nackedMessages_.empty());
ASSERT_TRUE(consumerImplPtr->negativeAcksTracker_->nackedMessages_.empty());

client.close();
}
Expand Down

0 comments on commit 24ab12c

Please sign in to comment.