Some ruminations on tag dispatching


Motivation

Imagine that you have the following matrix type

template<typename T>
class Matrix {};

Here we have chosen a very simple type, but in practice matrix types are usually parametrized by many more template parameters: dense or sparse storage, symmetric shape, dynamic or static dimensions… To convince yourself, you can have a look at:

to cite just a few.
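
To give an idea, a more realistic declaration could look like the sketch below. All the names in it (RichMatrix, Dense, Sparse, General, Symmetric, Dynamic) are invented for illustration and are not taken from any particular library.

```cpp
#include <cstddef>

// Hypothetical policy tags, invented for this sketch.
struct Dense {};
struct Sparse {};
struct General {};
struct Symmetric {};

// Assumed convention: 0 means "dimension only known at run time".
constexpr std::size_t Dynamic = 0;

template <typename T, typename Storage = Dense, typename Shape = General,
          std::size_t Rows = Dynamic, std::size_t Cols = Dynamic>
class RichMatrix
{
 public:
  // True only when both dimensions are fixed at compile time.
  static constexpr bool has_static_size =
      (Rows != Dynamic) && (Cols != Dynamic);
};

// A dense, general matrix of float with run-time dimensions...
using DynamicMatrix = RichMatrix<float>;
// ...and a sparse, symmetric 3x3 matrix of double.
using SmallSymmetric = RichMatrix<double, Sparse, Symmetric, 3, 3>;

static_assert(!DynamicMatrix::has_static_size, "run-time dimensions");
static_assert(SmallSymmetric::has_static_size, "compile-time dimensions");
```

Every extra parameter multiplies the number of cases a routine such as the matrix product may want to treat specially.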

The problems begin when you want to branch to the most specialized (and most efficient) subroutine for a given computation. For instance, to perform a matrix product you want to call BLAS or BLIS when T is a supported type, and fall back to a generic subroutine otherwise.

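#include <iostream>
#include <type_traits>

template <typename T>
void matrixProduct(Matrix<T>& AB, const Matrix<T>& A,
                   const Matrix<T>& B)
{
  // The condition is always false, but it depends on T, so the
  // assertion only fires when this overload is instantiated.
  static_assert(!std::is_same<T, T>::value, /* deferred false */
                "Missing specialization");
}
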
template <typename T>
typename std::enable_if<std::is_arithmetic<T>::value>::type
    matrixProduct(Matrix<T>& AB, const Matrix<T>& A,
                  const Matrix<T>& B)
{
  std::cout << "\nUse a generic (in-house) implementation";
}

template <typename T>
typename std::enable_if<std::is_same<T, float>::value||std::is_same<T, double>::value>::type
    matrixProduct(Matrix<T>& AB, const Matrix<T>& A,
                  const Matrix<T>& B)
{
  std::cout << "\nUse a BLAS call";
}

int main()
{
  struct A
  {
  };

  Matrix<A> aResult, aM;
  Matrix<int> iResult, iM;
  Matrix<double> dResult, dM;

  // matrixProduct(aResult, aM, aM);  // static assert error OK
  matrixProduct(iResult, iM, iM);  // Error!
  matrixProduct(dResult, dM, dM);  // Error!
}

The compiler error message is

...
error: call of overloaded ‘matrixProduct(Matrix<double>&, Matrix<double>&, Matrix<double>&)’ is ambiguous
   matrixProduct(dResult, dM, dM);  // Error!
...

A quick and dirty fix is to make the enable_if conditions mutually exclusive:

template <typename T>
typename std::enable_if<!std::is_arithmetic<T>::value>::type
    matrixProduct(Matrix<T>& AB, const Matrix<T>& A,
                  const Matrix<T>& B)
{
  static_assert(!std::is_same<T, T>::value, /* deferred false */
                "Missing specialization");
}

template <typename T>
typename std::enable_if<(std::is_arithmetic<T>::value) &&
                        (!(std::is_same<T, float>::value ||
                           std::is_same<T, double>::value))>::type
    matrixProduct(Matrix<T>& AB, const Matrix<T>& A,
                  const Matrix<T>& B)
{
  std::cout << "\nUse a generic (in-house) implementation";
}

template <typename T>
typename std::enable_if<std::is_same<T, float>::value ||
                        std::is_same<T, double>::value>::type
    matrixProduct(Matrix<T>& AB, const Matrix<T>& A,
                  const Matrix<T>& B)
{
  std::cout << "\nUse a BLAS call";
}

However, SFINAE is clearly not the right tool here: we have to manage mutual exclusion by hand to avoid ambiguities, and this becomes intractable as soon as the number of possible specializations increases.
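
To make the bookkeeping explicit, here is a minimal sketch in which each enable_if condition is given a name. The extra "integer kernel" backend and all the trait names (UsesBlas, UsesIntegerKernel, UsesGeneric, UsesFallback, enabledOverloads) are hypothetical. The only point is that adding one backend forces us to edit the condition of the generic one.

```cpp
#include <type_traits>

template <typename T>
struct UsesBlas
    : std::integral_constant<bool, std::is_same<T, float>::value ||
                                       std::is_same<T, double>::value>
{
};

// Hypothetical extra backend: a dedicated kernel for integral types.
template <typename T>
struct UsesIntegerKernel
    : std::integral_constant<bool, std::is_integral<T>::value>
{
};

// The generic condition must negate every more specific one.
template <typename T>
struct UsesGeneric
    : std::integral_constant<bool, std::is_arithmetic<T>::value &&
                                       !UsesBlas<T>::value &&
                                       !UsesIntegerKernel<T>::value>
{
};

template <typename T>
struct UsesFallback
    : std::integral_constant<bool, !std::is_arithmetic<T>::value>
{
};

// Number of overloads that would be enabled for T. Anything other
// than 1 means an ambiguous call (> 1) or no viable overload (0).
template <typename T>
constexpr int enabledOverloads()
{
  return int(UsesBlas<T>::value) + int(UsesIntegerKernel<T>::value) +
         int(UsesGeneric<T>::value) + int(UsesFallback<T>::value);
}

struct Opaque {};  // a non-arithmetic element type

static_assert(enabledOverloads<double>() == 1, "BLAS only");
static_assert(enabledOverloads<int>() == 1, "integer kernel only");
static_assert(enabledOverloads<long double>() == 1, "generic only");
static_assert(enabledOverloads<Opaque>() == 1, "fallback only");
```

Forgetting the !UsesIntegerKernel term in UsesGeneric would bring the count for int to 2, which is exactly the ambiguity shown in the compiler error above.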

Tag dispatch

The problem can be classically solved using tag dispatch to disambiguate the subroutine call:

struct UndefinedTag;

template <typename T>
void matrixProduct(const UndefinedTag&, Matrix<T>& AB,
                   const Matrix<T>& A, const Matrix<T>& B)
{
  static_assert(!std::is_same<T, T>::value, /* deferred false */
                "Missing specialization");
}

struct GenericTag;

template <typename T>
typename std::enable_if<std::is_arithmetic<T>::value>::type
    matrixProduct(const GenericTag&, Matrix<T>& AB,
                  const Matrix<T>& A, const Matrix<T>& B)
{
  std::cout << "\nUse a generic (in-house) implementation";
}

struct BlasTag;

template <typename T>
typename std::enable_if<std::is_same<T, float>::value ||
                        std::is_same<T, double>::value>::type
    matrixProduct(const BlasTag&, Matrix<T>& AB, const Matrix<T>& A,
                  const Matrix<T>& B)
{
  std::cout << "\nUse a BLAS call";
}

// Define a tag hierarchy
// -> this induces an order in the dispatch
struct UndefinedTag
{
};

struct GenericTag : UndefinedTag
{
};

struct BlasTag : GenericTag
{
};

template <typename T>
void matrixProduct(Matrix<T>& AB, const Matrix<T>& A,
                   const Matrix<T>& B)
{
  matrixProduct(BlasTag(), AB, A, B);
}

What I do not like about this approach:

  • forward declarations of the tags are required,
  • the tag hierarchy is defined outside the main matrixProduct function,
  • there is no fine-grained control when we want to mix local and global priority setups.
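
To see the dispatch order in action, here is a small self-contained variant of the tag dispatch code. It is only a sketch: the overloads return a label instead of printing or asserting, they take a single matrix argument, and the names backendImpl, selectedBackend, Opaque and checkTagDispatch are made up for the occasion.

```cpp
#include <cassert>
#include <string>
#include <type_traits>

template <typename T>
class Matrix {};

// Tag hierarchy: the most derived tag has the highest priority.
// It is declared first, before any overload that mentions a tag.
struct UndefinedTag {};
struct GenericTag : UndefinedTag {};
struct BlasTag : GenericTag {};

// Catch-all overload. It returns a label instead of triggering a
// static_assert so that the selection can be observed at run time.
template <typename T>
std::string backendImpl(const UndefinedTag&, const Matrix<T>&)
{
  return "undefined";
}

template <typename T>
typename std::enable_if<std::is_arithmetic<T>::value, std::string>::type
    backendImpl(const GenericTag&, const Matrix<T>&)
{
  return "generic";
}

template <typename T>
typename std::enable_if<std::is_same<T, float>::value ||
                            std::is_same<T, double>::value,
                        std::string>::type
    backendImpl(const BlasTag&, const Matrix<T>&)
{
  return "blas";
}

// Entry point: always start from the most derived tag. Overloads
// removed by enable_if are skipped, and among the remaining ones the
// tag closest to BlasTag in the hierarchy wins.
template <typename T>
std::string selectedBackend(const Matrix<T>& m)
{
  return backendImpl(BlasTag(), m);
}

struct Opaque {};  // a non-arithmetic element type

inline void checkTagDispatch()
{
  assert(selectedBackend(Matrix<double>()) == "blas");
  assert(selectedBackend(Matrix<int>()) == "generic");
  assert(selectedBackend(Matrix<Opaque>()) == "undefined");
}
```

No mutual exclusion is needed anymore: for double both the BLAS and the generic overloads are viable, and the derived-to-base conversion ranking picks the BLAS one.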

Attempt for a more flexible implementation

We can avoid the tag forward declarations by using a “generic” priority mechanism defined as follows:

template <unsigned int N>
struct PriorityTag : PriorityTag<N - 1>
{
};

template <>
struct PriorityTag<0>
{
};
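
Before applying it to matrices, here is a tiny sketch of what this hierarchy gives us. The toy overload set level and the helpers highestAvailableLevel and checkPriority are invented for this example only.

```cpp
#include <cassert>
#include <type_traits>

template <unsigned int N>
struct PriorityTag : PriorityTag<N - 1>
{
};

template <>
struct PriorityTag<0>
{
};

// Each tag derives (indirectly) from every lower-numbered tag...
static_assert(std::is_base_of<PriorityTag<0>, PriorityTag<2>>::value,
              "PriorityTag<2> converts to PriorityTag<0>");
// ...but never from a higher-numbered one.
static_assert(!std::is_base_of<PriorityTag<2>, PriorityTag<1>>::value,
              "PriorityTag<1> does not convert to PriorityTag<2>");

// Toy overload set, invented for this sketch.
inline int level(const PriorityTag<0>&) { return 0; }
inline int level(const PriorityTag<1>&) { return 1; }

// No overload takes PriorityTag<3> or PriorityTag<2>, so the call
// falls through to the nearest base that has one: PriorityTag<1>.
inline int highestAvailableLevel()
{
  return level(PriorityTag<3>());
}

inline void checkPriority()
{
  assert(highestAvailableLevel() == 1);
}
```

In other words, the caller may start from any sufficiently large N, and overload resolution selects the viable overload with the largest priority not exceeding N.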

Now the implementation would be:

template <typename T>
void matrixProduct(const PriorityTag<0>&, Matrix<T>& AB,
                   const Matrix<T>& A, const Matrix<T>& B)
{
  static_assert(!std::is_same<T, T>::value, /* deferred false */
                "Missing specialization");
}

template <typename T>
typename std::enable_if<std::is_arithmetic<T>::value>::type
    matrixProduct(const PriorityTag<1>&, Matrix<T>& AB,
                  const Matrix<T>& A, const Matrix<T>& B)
{
  std::cout << "\nUse a generic (in-house) implementation";
}

template <typename T>
typename std::enable_if<std::is_same<T, float>::value ||
                        std::is_same<T, double>::value>::type
    matrixProduct(const PriorityTag<2>&, Matrix<T>& AB, const Matrix<T>& A,
                  const Matrix<T>& B)
{
  std::cout << "\nUse a BLAS call";
}

template <typename T>
void matrixProduct(Matrix<T>& AB, const Matrix<T>& A,
                   const Matrix<T>& B)
{
  matrixProduct(PriorityTag<2>(), AB, A, B);
}

Now we do not need forward declarations anymore, but the resulting code is not easy to read: it would be better to have something like UseBlas instead of PriorityTag<2>. Moreover, we want a mechanism to easily modify the priority order.

The enum class (a scoped enumeration, available since C++11) comes to the rescue. We get:

template <typename PRIORITY, PRIORITY integer>
using PriorityConfiguration = PriorityTag<static_cast<
    typename std::underlying_type<PRIORITY>::type>(integer)>;

template <typename PRIORITY, typename T>
void matrixProduct(
    const PriorityConfiguration<PRIORITY, PRIORITY::Undefined>&,
    Matrix<T>& AB, const Matrix<T>& A, const Matrix<T>& B)
{
  static_assert(!std::is_same<T, T>::value, /* deferred false */
                "Missing specialization");
}

template <typename PRIORITY, typename T>
typename std::enable_if<std::is_arithmetic<T>::value>::type
    matrixProduct(
        const PriorityConfiguration<PRIORITY, PRIORITY::Generic>&,
        Matrix<T>& AB, const Matrix<T>& A, const Matrix<T>& B)
{
  std::cout << "\nUse a generic (in-house) implementation";
}

template <typename PRIORITY, typename T>
typename std::enable_if<std::is_same<T, float>::value ||
                        std::is_same<T, double>::value>::type
    matrixProduct(
        const PriorityConfiguration<PRIORITY, PRIORITY::Blas>&,
        Matrix<T>& AB, const Matrix<T>& A, const Matrix<T>& B)
{
  std::cout << "\nUse a BLAS call";
}

template <typename T>
void matrixProduct(Matrix<T>& AB, const Matrix<T>& A,
                   const Matrix<T>& B)
{
  enum class LocalPriority : unsigned int
  {
    Undefined,
    Generic,
    Blas, 

    END
  };

  matrixProduct<LocalPriority>(
      PriorityConfiguration<LocalPriority, LocalPriority::END>(), AB,
      A, B);
}

With this approach it is very easy to:

  • modify priority
  enum class LocalPriority : unsigned int
  {
    Undefined,
    Blas,      // instead of Generic
    Generic,   // instead of Blas

    END
  };
  • dismiss some specializations
  enum class LocalPriority : unsigned int
  {
    Undefined,
    Generic,

    END,

    Blas
  };
  • use a local or a global priority setup
enum class GlobalPriority : unsigned int
{
  Undefined,
  Generic,
  Blas,
  Blis,
  Static_Size,

  END,

};

template <typename T>
void matrixProduct(Matrix<T>& AB, const Matrix<T>& A,
                   const Matrix<T>& B)
{
  matrixProduct<GlobalPriority>(
      PriorityConfiguration<GlobalPriority, GlobalPriority::END>(),
      AB, A, B);
}
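
The effect of these three configurations can be checked with the following sketch. As before it is a simplified variant, not the code from the repository: the overloads return a label and take a single matrix, and the names backendImpl, selectedBackend, BlasFirst, GenericFirst, BlasDisabled and checkPriorityConfigurations are assumptions of this example.

```cpp
#include <cassert>
#include <string>
#include <type_traits>

template <typename T>
class Matrix {};

template <unsigned int N>
struct PriorityTag : PriorityTag<N - 1> {};
template <>
struct PriorityTag<0> {};

template <typename PRIORITY, PRIORITY value>
using PriorityConfiguration = PriorityTag<static_cast<
    typename std::underlying_type<PRIORITY>::type>(value)>;

template <typename PRIORITY, typename T>
std::string backendImpl(
    const PriorityConfiguration<PRIORITY, PRIORITY::Undefined>&,
    const Matrix<T>&)
{
  return "undefined";
}

template <typename PRIORITY, typename T>
typename std::enable_if<std::is_arithmetic<T>::value, std::string>::type
    backendImpl(const PriorityConfiguration<PRIORITY, PRIORITY::Generic>&,
                const Matrix<T>&)
{
  return "generic";
}

template <typename PRIORITY, typename T>
typename std::enable_if<std::is_same<T, float>::value ||
                            std::is_same<T, double>::value,
                        std::string>::type
    backendImpl(const PriorityConfiguration<PRIORITY, PRIORITY::Blas>&,
                const Matrix<T>&)
{
  return "blas";
}

// Three priority setups: only the enumerator order differs.
enum class BlasFirst : unsigned int { Undefined, Generic, Blas, END };
enum class GenericFirst : unsigned int { Undefined, Blas, Generic, END };
enum class BlasDisabled : unsigned int { Undefined, Generic, END, Blas };

// The dispatch starts at END, so enumerators listed after END are
// never reachable.
template <typename PRIORITY, typename T>
std::string selectedBackend(const Matrix<T>& m)
{
  return backendImpl<PRIORITY>(
      PriorityConfiguration<PRIORITY, PRIORITY::END>(), m);
}

inline void checkPriorityConfigurations()
{
  assert(selectedBackend<BlasFirst>(Matrix<double>()) == "blas");
  assert(selectedBackend<GenericFirst>(Matrix<double>()) == "generic");
  assert(selectedBackend<BlasDisabled>(Matrix<double>()) == "generic");
}
```

The overloads themselves are never touched: swapping or disabling a backend is a one-line change in the enumeration.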

Another way, constexpr if

There is another solution, though it is still not “directly” available because it needs the C++17 constexpr if. However, it can be partially emulated in C++14; see the interesting blog post by Baptiste Wicht and the associated implementation. Compared to a “true” constexpr if, this emulation does not allow the branches to return different types. Anyway, here is the code:

template <typename T>
void matrixProduct_undefined(Matrix<T>& AB, const Matrix<T>& A,
                             const Matrix<T>& B)
{
  static_assert(!std::is_same<T, T>::value, /* deferred false */
                "Missing specialization");
}

template <typename T>
constexpr bool matrixProduct_generic_v = std::is_arithmetic<T>::value;

template <typename T>
typename std::enable_if<matrixProduct_generic_v<T>>::type
    matrixProduct_generic(Matrix<T>& AB, const Matrix<T>& A,
                          const Matrix<T>& B)
{
  std::cout << "\nUse a generic (in-house) implementation";
}

template <typename T>
constexpr bool matrixProduct_blas_v =
    std::is_same<T, float>::value || std::is_same<T, double>::value;

template <typename T>
typename std::enable_if<matrixProduct_blas_v<T>>::type
    matrixProduct_blas(Matrix<T>& AB, const Matrix<T>& A,
                       const Matrix<T>& B)
{
  std::cout << "\nUse a BLAS call";
}

template <typename T>
void matrixProduct(Matrix<T>& AB, const Matrix<T>& A, const Matrix<T>& B)
{
    static_if<matrixProduct_blas_v<T>>([&](auto id)
                                       {
                                           matrixProduct_blas(id(AB), id(A), id(B));
                                       })
        .else_([&](auto id)
               {
                   static_if<matrixProduct_generic_v<T>>(
                       [&](auto id)
                       {
                           matrixProduct_generic(id(AB), id(A), id(B));
                       })
                       .else_([&](auto id)
                              {
                                  matrixProduct_undefined(id(AB), id(A), id(B));
                              });
               });
}

The code that defines static_if is here:

// Code from:
// https://github.com/wichtounet/cpp_utils/blob/master/static_if.hpp
//
// See:
// http://baptiste-wicht.com/posts/2015/07/simulate-static_if-with-c11c14.html
// http://lists.boost.org/Archives/boost/2014/08/216607.php
//
#include <utility>  // std::forward

namespace static_if_detail
{
  struct identity
  {
    template <typename T>
    T operator()(T&& x) const
    {
      return std::forward<T>(x);
    }
  };

  template <bool Cond>
  struct statement
  {
    template <typename F>
    void then(const F& f)
    {
      f(identity());
    }

    template <typename F>
    void else_(const F&)
    {
    }
  };

  template <>
  struct statement<false>
  {
    template <typename F>
    void then(const F&)
    {
    }

    template <typename F>
    void else_(const F& f)
    {
      f(identity());
    }
  };
}

template <bool Cond, typename F>
static_if_detail::statement<Cond> static_if(F const& f)
{
  static_if_detail::statement<Cond> if_;
  if_.then(f);
  return if_;
}
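
For comparison, with a compiler that supports C++17 the same selection could be written directly with if constexpr. This is only a sketch under that assumption: it returns a label instead of calling a real routine, and the names dependent_false, selectedBackend and checkIfConstexpr are made up for this example.

```cpp
#include <cassert>
#include <string>
#include <type_traits>

template <typename T>
class Matrix {};

// Helper that is always false but depends on T, so the static_assert
// below only fires when its branch is actually instantiated.
template <typename T>
struct dependent_false : std::false_type {};

template <typename T>
std::string selectedBackend(const Matrix<T>&)
{
  // Branches are tested in order; the discarded ones are not
  // instantiated, so no enable_if or tag is needed.
  if constexpr (std::is_same_v<T, float> || std::is_same_v<T, double>) {
    return "blas";
  } else if constexpr (std::is_arithmetic_v<T>) {
    return "generic";
  } else {
    static_assert(dependent_false<T>::value, "Missing specialization");
  }
}

inline void checkIfConstexpr()
{
  assert(selectedBackend(Matrix<double>()) == "blas");
  assert(selectedBackend(Matrix<int>()) == "generic");
}
```

Here the priority order is simply the order of the branches, which is easy to read but has to be edited in place, whereas the enumeration-based setup keeps it in a separate configuration.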

Final word

For the moment I use the priority-based solution.

You can find the code on GitHub.
