Saturday, August 22, 2020

C++ instantiate class explicitly with default template values, except the second-last value

This could be a really simple question.

I have the following (templated) class definition.

template <
    /// Element type for A matrix operand
    typename ElementA_,
    /// Layout type for A matrix operand
    typename LayoutA_,
    /// Element type for B matrix operand
    typename ElementB_,
    /// Layout type for B matrix operand
    typename LayoutB_,
    /// Element type for C and D matrix operands
    typename ElementC_,
    /// Layout type for C and D matrix operands
    typename LayoutC_,
    /// Element type for internal accumulation
    typename ElementAccumulator_ = ElementC_,
    /// Operator class tag
    typename OperatorClass_ = arch::OpClassSimt,
    /// Tag indicating architecture to tune for
    typename ArchTag_ = arch::Sm70,
    //... a lot more typenames with default values...,
    typename Operator_ = typename DefaultGemmConfiguration<
        OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_,
        ElementAccumulator_>::Operator,
    /// Whether Beta is zero or not
    bool IsBetaZero = false>
class Gemm {
 public:
...

I would like to instantiate that class with all its default template arguments except one: the second-to-last parameter in the template list (Operator_). How can I do this? My guess would be something like this:

using CutlassGemm = cutlass::gemm::device::Gemm<float,        // Data-type of A matrix
                                                  ColumnMajor,  // Layout of A matrix
                                                  float,        // Data-type of B matrix
                                                  ColumnMajor,  // Layout of B matrix
                                                  float,        // Data-type of C matrix
                                                  ColumnMajor,  // Layout of C matrix
                                                  default,
                                                  default,
                                                  // a lot more default values
                                                  default,
                                                  arch::MinimumAddOp,  // <---- this one should NOT be the default
                                                  default>;

  // Define a CUTLASS GEMM type
  CutlassGemm gemm_operator;
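
For context on why the guess above does not compile: C++ has no placeholder such as `default` for a template argument. Template arguments are positional, so every argument before the one being changed has to be written out, and only the trailing ones can be omitted. The following is a minimal sketch of two ways to do that. It uses a shortened, hypothetical stand-in for the class (the struct `Gemm`, the tag types, and the alias `GemmWithOperator` are illustrative assumptions, not the CUTLASS definitions), so the real parameter list would need the same treatment for each of its elided parameters.

```cpp
#include <type_traits>

// Hypothetical stand-in tags; these are NOT the CUTLASS types.
struct OpClassSimt {};
struct Sm70 {};
struct DefaultOp {};
struct MinimumAddOp {};

// Shortened stand-in for the class in the question: one required parameter,
// then several defaulted ones, with Operator second-to-last.
template <typename ElementA,
          typename ElementAccumulator = ElementA,
          typename OperatorClass = OpClassSimt,
          typename ArchTag = Sm70,
          typename Operator = DefaultOp,
          bool IsBetaZero = false>
struct Gemm {
  // Re-export the arguments so the defaults can be read back from the type.
  using ElementAccumulatorT = ElementAccumulator;
  using OperatorClassT = OperatorClass;
  using ArchTagT = ArchTag;
  using OperatorT = Operator;
  static constexpr bool kIsBetaZero = IsBetaZero;
};

// Option 1: repeat every default by hand up to Operator.
// The trailing IsBetaZero can still be left out.
using GemmExplicit = Gemm<float, float, OpClassSimt, Sm70, MinimumAddOp>;

// Option 2: an alias template that instantiates the all-defaults type,
// reads its defaults back through the member typedefs, and overrides
// only Operator. This avoids copying the default values by hand.
template <typename ElementA, typename Operator>
using GemmWithOperator =
    Gemm<ElementA,
         typename Gemm<ElementA>::ElementAccumulatorT,
         typename Gemm<ElementA>::OperatorClassT,
         typename Gemm<ElementA>::ArchTagT,
         Operator>;

using GemmViaAlias = GemmWithOperator<float, MinimumAddOp>;
```

Option 2 only works if the class exposes its template arguments as member typedefs, which is an assumption here; otherwise the defaults have to be repeated by hand as in Option 1. Both aliases name the same type, so either can be used wherever the customized instantiation is needed.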
