From 6c92eabcfe60a3cfd43b51cedec7ce41061a5d62 Mon Sep 17 00:00:00 2001
From: tzik <tzik@chromium.org>
Date: Fri, 25 Nov 2016 07:56:36 -0800
Subject: [PATCH] Support external task cancellation mechanisms in
 base::Callback::IsCancelled

This CL cleans up the implementation of base::Callback::IsCancelled(),
and exposes an injection point as CallbackCancellationTraits, so that
external code can implement the cancellation handling.
In addition, this CL adds a specialization of CallbackCancellationTraits
for blink::TaskHandle::Runner, so that the blink task scheduler can properly
handle cancellation of a task posted via WebTaskRunner::postCancellableTask().

Review-Url: https://codereview.chromium.org/2487493004
Cr-Commit-Position: refs/heads/master@{#434511}
---
 base/bind_helpers.h                           | 45 ++++++++++++
 base/bind_internal.h                          | 69 ++++++++-----------
 third_party/WebKit/Source/platform/DEPS       |  1 +
 .../WebKit/Source/platform/WebTaskRunner.cpp  | 23 ++++++-
 .../WebKit/Source/platform/WebTaskRunner.h    |  3 +-
 .../Source/platform/WebTaskRunnerTest.cpp     | 51 ++++++++++++++
 .../scheduler/test/fake_web_task_runner.cc    |  4 ++
 .../scheduler/test/fake_web_task_runner.h     |  5 +-
 8 files changed, 156 insertions(+), 45 deletions(-)

diff --git a/base/bind_helpers.h b/base/bind_helpers.h
index c7c7be8ee840a..1e65225572906 100644
--- a/base/bind_helpers.h
+++ b/base/bind_helpers.h
@@ -179,6 +179,9 @@ struct BindUnwrapTraits;
 
 namespace internal {
 
+template <typename Functor, typename SFINAE = void>
+struct FunctorTraits;
+
 template <typename T>
 class UnretainedWrapper {
  public:
@@ -521,6 +524,48 @@ struct BindUnwrapTraits<internal::PassedWrapper<T>> {
   }
 };
 
+// CallbackCancellationTraits allows customization of Callback's cancellation
+// semantics. By default, callbacks are not cancellable. A specialization should
+// set is_cancellable = true and implement a static IsCancelled() that returns
+// whether the callback should be cancelled.
+template <typename Functor, typename BoundArgsTuple, typename SFINAE = void>
+struct CallbackCancellationTraits {
+  static constexpr bool is_cancellable = false;
+};
+
+// Specialization for method bound to weak pointer receiver.
+template <typename Functor, typename... BoundArgs>
+struct CallbackCancellationTraits<
+    Functor,
+    std::tuple<BoundArgs...>,
+    typename std::enable_if<
+        internal::IsWeakMethod<internal::FunctorTraits<Functor>::is_method,
+                               BoundArgs...>::value>::type> {
+  static constexpr bool is_cancellable = true;
+
+  template <typename Receiver, typename... Args>
+  static bool IsCancelled(const Functor&,
+                          const Receiver& receiver,
+                          const Args&...) {
+    return !receiver;
+  }
+};
+
+// Specialization for a nested bind.
+template <typename Signature,
+          typename... BoundArgs,
+          internal::CopyMode copy_mode,
+          internal::RepeatMode repeat_mode>
+struct CallbackCancellationTraits<Callback<Signature, copy_mode, repeat_mode>,
+                                  std::tuple<BoundArgs...>> {
+  static constexpr bool is_cancellable = true;
+
+  template <typename Functor>
+  static bool IsCancelled(const Functor& functor, const BoundArgs&...) {
+    return functor.IsCancelled();
+  }
+};
+
 }  // namespace base
 
 #endif  // BASE_BIND_HELPERS_H_
diff --git a/base/bind_internal.h b/base/bind_internal.h
index 88e764547f816..8988bdca226de 100644
--- a/base/bind_internal.h
+++ b/base/bind_internal.h
@@ -130,7 +130,7 @@ struct ForceVoidReturn<R(Args...)> {
 // FunctorTraits<>
 //
 // See description at top of file.
-template <typename Functor, typename SFINAE = void>
+template <typename Functor, typename SFINAE>
 struct FunctorTraits;
 
 // For a callable type that is convertible to the corresponding function type.
@@ -387,43 +387,24 @@ IsNull(const Functor&) {
   return false;
 }
 
-template <typename Functor, typename... BoundArgs>
-struct BindState;
+// Used by ApplyCancellationTraits below.
+template <typename Functor, typename BoundArgsTuple, size_t... indices>
+bool ApplyCancellationTraitsImpl(const Functor& functor,
+                                 const BoundArgsTuple& bound_args,
+                                 IndexSequence<indices...>) {
+  return CallbackCancellationTraits<Functor, BoundArgsTuple>::IsCancelled(
+      functor, base::get<indices>(bound_args)...);
+}
 
-template <typename BindStateType, typename SFINAE = void>
-struct CancellationChecker {
-  static constexpr bool is_cancellable = false;
-  static bool Run(const BindStateBase*) {
-    return false;
-  }
-};
-
-template <typename Functor, typename... BoundArgs>
-struct CancellationChecker<
-    BindState<Functor, BoundArgs...>,
-    typename std::enable_if<IsWeakMethod<FunctorTraits<Functor>::is_method,
-                                         BoundArgs...>::value>::type> {
-  static constexpr bool is_cancellable = true;
-  static bool Run(const BindStateBase* base) {
-    using BindStateType = BindState<Functor, BoundArgs...>;
-    const BindStateType* bind_state = static_cast<const BindStateType*>(base);
-    return !base::get<0>(bind_state->bound_args_);
-  }
-};
-
-template <typename Signature,
-          typename... BoundArgs,
-          CopyMode copy_mode,
-          RepeatMode repeat_mode>
-struct CancellationChecker<
-    BindState<Callback<Signature, copy_mode, repeat_mode>, BoundArgs...>> {
-  static constexpr bool is_cancellable = true;
-  static bool Run(const BindStateBase* base) {
-    using Functor = Callback<Signature, copy_mode, repeat_mode>;
-    using BindStateType = BindState<Functor, BoundArgs...>;
-    const BindStateType* bind_state = static_cast<const BindStateType*>(base);
-    return bind_state->functor_.IsCancelled();
-  }
+// Relays |base| to corresponding CallbackCancellationTraits<>::IsCancelled().
+// Returns true if the callback |base| represents is canceled.
+template <typename BindStateType>
+bool ApplyCancellationTraits(const BindStateBase* base) {
+  const BindStateType* storage = static_cast<const BindStateType*>(base);
+  static constexpr size_t num_bound_args =
+      std::tuple_size<decltype(storage->bound_args_)>::value;
+  return ApplyCancellationTraitsImpl(storage->functor_, storage->bound_args_,
+                                     MakeIndexSequence<num_bound_args>());
 };
 
 // Template helpers to detect using Bind() on a base::Callback without any
@@ -449,14 +430,17 @@ struct BindingCallbackWithNoArgs<Callback<Signature, copy_mode, repeat_mode>,
 template <typename Functor, typename... BoundArgs>
 struct BindState final : BindStateBase {
   using IsCancellable = std::integral_constant<
-      bool, CancellationChecker<BindState>::is_cancellable>;
+      bool,
+      CallbackCancellationTraits<Functor,
+                                 std::tuple<BoundArgs...>>::is_cancellable>;
 
   template <typename ForwardFunctor, typename... ForwardBoundArgs>
   explicit BindState(BindStateBase::InvokeFuncStorage invoke_func,
                      ForwardFunctor&& functor,
                      ForwardBoundArgs&&... bound_args)
-      // IsCancellable is std::false_type if the CancellationChecker<>::Run
-      // returns always false. Otherwise, it's std::true_type.
+      // IsCancellable mirrors CallbackCancellationTraits<>::is_cancellable:
+      // std::true_type if the callback can be cancelled, std::false_type
+      // otherwise. It selects the delegated constructor below.
       : BindState(IsCancellable{},
                   invoke_func,
                   std::forward<ForwardFunctor>(functor),
@@ -476,8 +460,9 @@ struct BindState final : BindStateBase {
                      BindStateBase::InvokeFuncStorage invoke_func,
                      ForwardFunctor&& functor,
                      ForwardBoundArgs&&... bound_args)
-      : BindStateBase(invoke_func, &Destroy,
-                      &CancellationChecker<BindState>::Run),
+      : BindStateBase(invoke_func,
+                      &Destroy,
+                      &ApplyCancellationTraits<BindState>),
         functor_(std::forward<ForwardFunctor>(functor)),
         bound_args_(std::forward<ForwardBoundArgs>(bound_args)...) {
     DCHECK(!IsNull(functor_));
diff --git a/third_party/WebKit/Source/platform/DEPS b/third_party/WebKit/Source/platform/DEPS
index 06327975b84e7..c15cda80aafc4 100644
--- a/third_party/WebKit/Source/platform/DEPS
+++ b/third_party/WebKit/Source/platform/DEPS
@@ -2,6 +2,7 @@ include_rules = [
     # To whitelist base/ stuff Blink is allowed to include, we list up all
     # directories and files instead of writing 'base/'.
     "+base/bind.h",
+    "+base/bind_helpers.h",
     "+base/callback.h",
     "+base/callback_forward.h",
     "+base/files",
diff --git a/third_party/WebKit/Source/platform/WebTaskRunner.cpp b/third_party/WebKit/Source/platform/WebTaskRunner.cpp
index 059abdf83d05f..aa2d25449dc39 100644
--- a/third_party/WebKit/Source/platform/WebTaskRunner.cpp
+++ b/third_party/WebKit/Source/platform/WebTaskRunner.cpp
@@ -4,8 +4,29 @@
 
 #include "platform/WebTaskRunner.h"
 
+#include "base/bind_helpers.h"
 #include "base/single_thread_task_runner.h"
 
+namespace base {
+
+using RunnerMethodType =
+    void (blink::TaskHandle::Runner::*)(const blink::TaskHandle&);
+
+template <>
+struct CallbackCancellationTraits<
+    RunnerMethodType,
+    std::tuple<WTF::WeakPtr<blink::TaskHandle::Runner>, blink::TaskHandle>> {
+  static constexpr bool is_cancellable = true;
+
+  static bool IsCancelled(RunnerMethodType,
+                          const WTF::WeakPtr<blink::TaskHandle::Runner>&,
+                          const blink::TaskHandle& handle) {
+    return !handle.isActive();
+  }
+};
+
+}  // namespace base
+
 namespace blink {
 
 class TaskHandle::Runner : public WTF::ThreadSafeRefCounted<Runner> {
@@ -15,7 +36,7 @@ class TaskHandle::Runner : public WTF::ThreadSafeRefCounted<Runner> {
 
   WTF::WeakPtr<Runner> asWeakPtr() { return m_weakPtrFactory.createWeakPtr(); }
 
-  bool isActive() { return static_cast<bool>(m_task); }
+  bool isActive() const { return m_task && !m_task->isCancelled(); }
 
   void cancel() {
     std::unique_ptr<WTF::Closure> task = std::move(m_task);
diff --git a/third_party/WebKit/Source/platform/WebTaskRunner.h b/third_party/WebKit/Source/platform/WebTaskRunner.h
index 4b233667eaeea..b6cf972f84d58 100644
--- a/third_party/WebKit/Source/platform/WebTaskRunner.h
+++ b/third_party/WebKit/Source/platform/WebTaskRunner.h
@@ -46,8 +46,9 @@ class BLINK_PLATFORM_EXPORT TaskHandle {
   TaskHandle(TaskHandle&&);
   TaskHandle& operator=(TaskHandle&&);
 
- private:
   class Runner;
+
+ private:
   friend class WebTaskRunner;
 
   explicit TaskHandle(RefPtr<Runner>);
diff --git a/third_party/WebKit/Source/platform/WebTaskRunnerTest.cpp b/third_party/WebKit/Source/platform/WebTaskRunnerTest.cpp
index 022a649a53e2a..bcff0c113ae43 100644
--- a/third_party/WebKit/Source/platform/WebTaskRunnerTest.cpp
+++ b/third_party/WebKit/Source/platform/WebTaskRunnerTest.cpp
@@ -18,6 +18,23 @@ void getIsActive(bool* isActive, TaskHandle* handle) {
   *isActive = handle->isActive();
 }
 
+class CancellationTestHelper {
+ public:
+  CancellationTestHelper() : m_weakPtrFactory(this) {}
+
+  WeakPtr<CancellationTestHelper> createWeakPtr() {
+    return m_weakPtrFactory.createWeakPtr();
+  }
+
+  void revokeWeakPtrs() { m_weakPtrFactory.revokeAll(); }
+  void incrementCounter() { ++m_counter; }
+  int counter() const { return m_counter; }
+
+ private:
+  int m_counter = 0;
+  WeakPtrFactory<CancellationTestHelper> m_weakPtrFactory;
+};
+
 }  // namespace
 
 TEST(WebTaskRunnerTest, PostCancellableTaskTest) {
@@ -99,4 +116,38 @@ TEST(WebTaskRunnerTest, PostCancellableTaskTest) {
   EXPECT_FALSE(handle.isActive());
 }
 
+TEST(WebTaskRunnerTest, CancellationCheckerTest) {
+  scheduler::FakeWebTaskRunner taskRunner;
+
+  int count = 0;
+  TaskHandle handle = taskRunner.postCancellableTask(
+      BLINK_FROM_HERE, WTF::bind(&increment, WTF::unretained(&count)));
+  EXPECT_EQ(0, count);
+
+  // TaskHandle::isActive should detect the deletion of the posted task.
+  auto queue = taskRunner.takePendingTasksForTesting();
+  ASSERT_EQ(1u, queue.size());
+  EXPECT_FALSE(queue[0].IsCancelled());
+  EXPECT_TRUE(handle.isActive());
+  queue.clear();
+  EXPECT_FALSE(handle.isActive());
+  EXPECT_EQ(0, count);
+
+  count = 0;
+  CancellationTestHelper helper;
+  handle = taskRunner.postCancellableTask(
+      BLINK_FROM_HERE, WTF::bind(&CancellationTestHelper::incrementCounter,
+                                 helper.createWeakPtr()));
+  EXPECT_EQ(0, helper.counter());
+
+  // The cancellation of the posted task should be propagated to TaskHandle.
+  queue = taskRunner.takePendingTasksForTesting();
+  ASSERT_EQ(1u, queue.size());
+  EXPECT_FALSE(queue[0].IsCancelled());
+  EXPECT_TRUE(handle.isActive());
+  helper.revokeWeakPtrs();
+  EXPECT_TRUE(queue[0].IsCancelled());
+  EXPECT_FALSE(handle.isActive());
+}
+
 }  // namespace blink
diff --git a/third_party/WebKit/Source/platform/scheduler/test/fake_web_task_runner.cc b/third_party/WebKit/Source/platform/scheduler/test/fake_web_task_runner.cc
index 2b33ed2e92aa9..0d3866696c4a8 100644
--- a/third_party/WebKit/Source/platform/scheduler/test/fake_web_task_runner.cc
+++ b/third_party/WebKit/Source/platform/scheduler/test/fake_web_task_runner.cc
@@ -115,5 +115,9 @@ void FakeWebTaskRunner::runUntilIdle() {
   }
 }
 
+std::deque<base::Closure> FakeWebTaskRunner::takePendingTasksForTesting() {
+  return std::move(data_->task_queue_);
+}
+
 }  // namespace scheduler
 }  // namespace blink
diff --git a/third_party/WebKit/Source/platform/scheduler/test/fake_web_task_runner.h b/third_party/WebKit/Source/platform/scheduler/test/fake_web_task_runner.h
index 1fbe4db0994aa..4a6c435a90570 100644
--- a/third_party/WebKit/Source/platform/scheduler/test/fake_web_task_runner.h
+++ b/third_party/WebKit/Source/platform/scheduler/test/fake_web_task_runner.h
@@ -5,11 +5,13 @@
 #ifndef THIRD_PARTY_WEBKIT_SOURCE_PLATFORM_SCHEDULER_TEST_FAKE_WEB_TASK_RUNNER_H_
 #define THIRD_PARTY_WEBKIT_SOURCE_PLATFORM_SCHEDULER_TEST_FAKE_WEB_TASK_RUNNER_H_
 
+#include <deque>
+
 #include "base/macros.h"
+#include "base/memory/ref_counted.h"
 #include "platform/WebTaskRunner.h"
 #include "wtf/PassRefPtr.h"
 #include "wtf/RefPtr.h"
-#include "base/memory/ref_counted.h"
 
 namespace blink {
 namespace scheduler {
@@ -35,6 +37,7 @@ class FakeWebTaskRunner : public WebTaskRunner {
   SingleThreadTaskRunner* toSingleThreadTaskRunner() override;
 
   void runUntilIdle();
+  std::deque<base::Closure> takePendingTasksForTesting();
 
  private:
   class Data;