Sunday, June 16, 2019

Nesting of subexpressions in expression templates

We are writing an expression template library to handle operations on values with a sparse gradient vector (first-order automatic differentiation). I am trying to figure out how to nest subexpressions either by reference or by value, depending on whether they are temporaries.

We have a class Scalar which contains a value and a sparse gradient vector. We use expression templates (as Eigen does) to avoid constructing and allocating too many temporary Scalar objects. Accordingly, the class Scalar inherits from ScalarBase<Scalar> (CRTP).

A binary operation (e.g. +, *) between objects of type ScalarBase< Left > and ScalarBase< Right > returns a ScalarBinaryOp<Left, Right, BinaryOp> object, which inherits from ScalarBase< ScalarBinaryOp<Left, Right, BinaryOp> >:

template< typename Left, typename Right >
ScalarBinaryOp< Left, Right, BinaryAdditionOp > operator+(
    const ScalarBase< Left >& left, const ScalarBase< Right >& right )
{
    return ScalarBinaryOp< Left, Right, BinaryAdditionOp >( static_cast< const Left& >( left ),
        static_cast< const Right& >( right ), BinaryAdditionOp{} );
}
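For reference, here is a minimal, self-contained sketch of the setup described so far. It holds the operands by const reference, as in the current implementation. Everything beyond the names used in the question is an assumption: the sparse gradient is modelled as a std::map<int, double>, and getValue/accumulateGradient are hypothetical member names.

```cpp
#include <cassert>
#include <map>

// CRTP base shared by Scalar and every expression node.
template< typename Derived >
struct ScalarBase
{
    const Derived& derived() const { return static_cast< const Derived& >( *this ); }
};

// Hypothetical Scalar: a value plus a sparse gradient (variable index -> partial).
struct Scalar : ScalarBase< Scalar >
{
    double value;
    std::map< int, double > gradient;

    Scalar( double v, std::map< int, double > g ) : value( v ), gradient( g ) {}

    // Evaluating an expression into a Scalar: the only place a gradient is built.
    template< typename Expr >
    Scalar( const ScalarBase< Expr >& expr ) : value( expr.derived().getValue() )
    {
        expr.derived().accumulateGradient( gradient, 1.0 );
    }

    double getValue() const { return value; }
    void accumulateGradient( std::map< int, double >& g, double scale ) const
    {
        for ( const auto& entry : gradient ) g[ entry.first ] += scale * entry.second;
    }
};

// d(a + b)/da = d(a + b)/db = 1.
struct BinaryAdditionOp
{
    double value( double a, double b ) const { return a + b; }
    double dLeft( double, double ) const { return 1.0; }
    double dRight( double, double ) const { return 1.0; }
};

template< typename Left, typename Right, typename BinaryOp >
struct ScalarBinaryOp : ScalarBase< ScalarBinaryOp< Left, Right, BinaryOp > >
{
    // Current behaviour: operands are always held by const reference.
    const Left& left;
    const Right& right;
    BinaryOp op;

    ScalarBinaryOp( const Left& l, const Right& r, BinaryOp o ) : left( l ), right( r ), op( o ) {}

    double getValue() const { return op.value( left.getValue(), right.getValue() ); }

    // Chain rule: push the scaled partial derivatives down to the leaves.
    void accumulateGradient( std::map< int, double >& g, double scale ) const
    {
        const double a = left.getValue(), b = right.getValue();
        left.accumulateGradient( g, scale * op.dLeft( a, b ) );
        right.accumulateGradient( g, scale * op.dRight( a, b ) );
    }
};

template< typename Left, typename Right >
ScalarBinaryOp< Left, Right, BinaryAdditionOp > operator+(
    const ScalarBase< Left >& left, const ScalarBase< Right >& right )
{
    return ScalarBinaryOp< Left, Right, BinaryAdditionOp >( static_cast< const Left& >( left ),
        static_cast< const Right& >( right ), BinaryAdditionOp{} );
}
```

In Scalar z = x + y + x; the intermediate node x + y is a temporary that lives until the end of that statement. The const references are therefore still valid while z is evaluated.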

ScalarBinaryOp must hold each operand (of type Left and Right) either by value or by reference. The holder type is selected through template specialization of RefTypeSelector< Expression >::Type.
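As an illustration, here is one way RefTypeSelector could be specialized. This is a common expression-template heuristic, not necessarily the library's current code. Leaf Scalar objects, which own a gradient, are nested by const reference. Expression nodes, which contain only references and small values, are nested by value. The Scalar and ScalarBinaryOp below are reduced to what the sketch needs.

```cpp
#include <type_traits>

template< typename Derived > struct ScalarBase {};
struct Scalar : ScalarBase< Scalar > { /* value and sparse gradient omitted */ };

// Primary template: expression nodes are lightweight, so nest them by value.
template< typename Expression >
struct RefTypeSelector
{
    typedef Expression Type;
};

// Specialization: a Scalar owns its sparse gradient, so nest it by const reference.
template<>
struct RefTypeSelector< Scalar >
{
    typedef const Scalar& Type;
};

template< typename Left, typename Right, typename BinaryOp >
struct ScalarBinaryOp : ScalarBase< ScalarBinaryOp< Left, Right, BinaryOp > >
{
    typename RefTypeSelector< Left >::Type left;    // const Scalar& or a copied node
    typename RefTypeSelector< Right >::Type right;
    BinaryOp op;

    ScalarBinaryOp( const Left& l, const Right& r, BinaryOp o ) : left( l ), right( r ), op( o ) {}
};
```

This selection depends only on the type, so it cannot by itself give the x+f(y) behaviour described below. The result of f(y) is also a Scalar, so it would still be nested by reference.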

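Currently this is always a const reference. That works for our current test cases, but holding a reference to a temporary subexpression does not seem correct or safe. A temporary is destroyed at the end of the full-expression that created it, so an expression object that outlives that statement (for example, one stored with auto) is left with a dangling reference.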
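The lifetime problem can be made concrete with a small tracer. Tracer, HoldsRef and makeTemporary are hypothetical illustrative names, not part of the library. An object that stores a const reference to a constructor argument keeps pointing at a temporary after that temporary has been destroyed. This is the situation of auto e = x + f(y); when ScalarBinaryOp stores const references.

```cpp
#include <cassert>

// Hypothetical tracer: counts how many instances are currently alive.
struct Tracer
{
    static int alive;
    Tracer() { ++alive; }
    Tracer( const Tracer& ) { ++alive; }
    ~Tracer() { --alive; }
};
int Tracer::alive = 0;

// Stands in for a ScalarBinaryOp that stores its operand by const reference.
struct HoldsRef
{
    explicit HoldsRef( const Tracer& t ) : ref( t ), aliveWhenBound( Tracer::alive ) {}
    const Tracer& ref;    // bound through a parameter: no lifetime extension
    int aliveWhenBound;   // number of live Tracers while the constructor ran
};

// Stands in for f(y): returns a temporary.
Tracer makeTemporary() { return Tracer(); }
```

After HoldsRef h( makeTemporary() ); the temporary existed while the constructor ran, but Tracer::alive is back to 0 once the statement ends. From then on h.ref dangles, and reading through it would be undefined behaviour.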

Obviously, we also do not want a Scalar object, which owns the sparse gradient vector, to be copied. If x and y are Scalar objects, the expression x+y should hold const references to x and y. However, if f is a function from Scalar to Scalar, x+f(y) should hold a const reference to x and the value of f(y).

Hence I would like to propagate the information about whether subexpressions are temporaries. I can add it to the expression's template parameters:

template< typename Left, typename Right, typename BinaryOp, bool LeftIsTemporary, bool RightIsTemporary >
class ScalarBinaryOp;

and to the RefTypeSelector:

RefTypeSelector< Expression, ExpressionIsTemporary >::Type
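Here is a minimal sketch of that two-parameter selector. It assumes temporaries are stored by value (moved into the expression) and lvalues by const reference. To keep the snippet standalone, double stands in for Scalar and the five-parameter ScalarBinaryOp is reduced to its data members.

```cpp
#include <cassert>
#include <type_traits>

// Default: the operand is an lvalue that outlives the expression, so refer to it.
template< typename Expression, bool ExpressionIsTemporary >
struct RefTypeSelector
{
    typedef const Expression& Type;
};

// Temporary operand: the expression must own a copy (or a moved-from value).
template< typename Expression >
struct RefTypeSelector< Expression, true >
{
    typedef Expression Type;
};

// Data members only; the ScalarBase parent and the evaluation code are omitted.
template< typename Left, typename Right, typename BinaryOp,
          bool LeftIsTemporary, bool RightIsTemporary >
struct ScalarBinaryOp
{
    typename RefTypeSelector< Left, LeftIsTemporary >::Type left;
    typename RefTypeSelector< Right, RightIsTemporary >::Type right;
    BinaryOp op;
};
```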

But then I would need to define four overloads for every binary operator:

template< typename Left, typename Right >
ScalarBinaryOp< Left, Right, BinaryAdditionOp, false, false > operator+(
    const ScalarBase< Left >& left, const ScalarBase< Right >& right );
template< typename Left, typename Right >
ScalarBinaryOp< Left, Right, BinaryAdditionOp, false, true > operator+(
    const ScalarBase< Left >& left, ScalarBase< Right >&& right );
template< typename Left, typename Right >
ScalarBinaryOp< Left, Right, BinaryAdditionOp, true, false > operator+(
    ScalarBase< Left >&& left, const ScalarBase< Right >& right );
template< typename Left, typename Right >
ScalarBinaryOp< Left, Right, BinaryAdditionOp, true, true > operator+(
    ScalarBase< Left >&& left, ScalarBase< Right >&& right );
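One detail worth checking about these four overloads: a temporary Scalar does bind to ScalarBase< Right >&& (a derived-to-base reference binding), and for rvalues that overload is preferred over the const& one. In the type-level sketch below, a hypothetical Picked< LeftIsTemporary, RightIsTemporary > tag replaces ScalarBinaryOp as the return type, so the selected overload can be inspected with decltype. The function f stands in for the question's function from Scalar to Scalar.

```cpp
#include <type_traits>

template< typename Derived > struct ScalarBase {};
struct Scalar : ScalarBase< Scalar > {};

// Hypothetical tag recording which overload was selected.
template< bool LeftIsTemporary, bool RightIsTemporary >
struct Picked {};

template< typename L, typename R >
Picked< false, false > operator+( const ScalarBase< L >&, const ScalarBase< R >& ) { return {}; }
template< typename L, typename R >
Picked< false, true > operator+( const ScalarBase< L >&, ScalarBase< R >&& ) { return {}; }
template< typename L, typename R >
Picked< true, false > operator+( ScalarBase< L >&&, const ScalarBase< R >& ) { return {}; }
template< typename L, typename R >
Picked< true, true > operator+( ScalarBase< L >&&, ScalarBase< R >&& ) { return {}; }

// Stands in for f: returns a temporary Scalar.
Scalar f( const Scalar& s ) { return s; }
```

ScalarBase< R >&& is not a forwarding reference, because R is not the whole parameter type. It therefore never binds to an lvalue, which is why exactly one overload is the best match for each combination.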

I would prefer to achieve this with perfect forwarding, but I do not know how to do it here. First, I cannot simply use "universal references" (forwarding references), because they match almost any argument type. It might be possible to combine forwarding references with SFINAE so that only ScalarBase-derived types are accepted, but I am not sure this is the way to go. I would also like to know whether the lvalue/rvalue category of the original arguments could be encoded in the types Left and Right that parameterize ScalarBinaryOp, instead of using the two additional boolean parameters, and how to retrieve that information.
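Here is a sketch of the perfect-forwarding variant. It assumes every expression type T derives from ScalarBase< T >. For a forwarding-reference parameter Left&&, template argument deduction yields Left = Scalar& (or const Scalar&) for an lvalue argument and Left = Scalar for an rvalue. The lvalue/rvalue information is therefore already carried by Left and Right, and std::is_lvalue_reference can read it back. A SFINAE guard (the hypothetical IsScalarExpression trait) keeps the operator from matching unrelated types. The hypothetical Nested trait plays the role of RefTypeSelector. Scalar here keeps only a value, and the sparse gradient is omitted. In view of the gcc 4.8.5 requirement, only C++11 traits are used (std::decay<T>::type, std::enable_if<...>::type), not the C++14 _t aliases.

```cpp
#include <cassert>
#include <type_traits>
#include <utility>

template< typename Derived > struct ScalarBase {};

// Simplified Scalar: value only, sparse gradient omitted.
struct Scalar : ScalarBase< Scalar >
{
    double value;
    explicit Scalar( double v ) : value( v ) {}
    double getValue() const { return value; }
};

// True when T, after stripping references and cv-qualifiers, derives from ScalarBase of itself.
template< typename T >
struct IsScalarExpression
    : std::is_base_of< ScalarBase< typename std::decay< T >::type >, typename std::decay< T >::type >
{};

// T is a deduced forwarding-reference type: "X&"/"const X&" for lvalues, "X" for rvalues.
// Lvalues are nested by const reference; rvalues are stored by value (moved in).
template< typename T >
struct Nested
{
    typedef typename std::conditional< std::is_lvalue_reference< T >::value,
        const typename std::decay< T >::type&,
        typename std::decay< T >::type >::type Type;
};

struct BinaryAdditionOp
{
    double operator()( double a, double b ) const { return a + b; }
};

template< typename Left, typename Right, typename BinaryOp >
struct ScalarBinaryOp : ScalarBase< ScalarBinaryOp< Left, Right, BinaryOp > >
{
    typename Nested< Left >::Type left;
    typename Nested< Right >::Type right;
    BinaryOp op;

    template< typename L, typename R >
    ScalarBinaryOp( L&& l, R&& r, BinaryOp o )
        : left( std::forward< L >( l ) ), right( std::forward< R >( r ) ), op( o ) {}

    double getValue() const { return op( left.getValue(), right.getValue() ); }
};

// SFINAE: this overload only exists when both operands are scalar expressions.
template< typename Left, typename Right >
typename std::enable_if< IsScalarExpression< Left >::value && IsScalarExpression< Right >::value,
    ScalarBinaryOp< Left, Right, BinaryAdditionOp > >::type
operator+( Left&& left, Right&& right )
{
    return ScalarBinaryOp< Left, Right, BinaryAdditionOp >(
        std::forward< Left >( left ), std::forward< Right >( right ), BinaryAdditionOp() );
}
```

With this design, x + f(y) has type ScalarBinaryOp< Scalar&, Scalar, BinaryAdditionOp >. Two expressions that differ only in the value category of their operands are therefore distinct types. Intermediate nodes returned by operator+ are prvalues, so they are moved into the enclosing node rather than referenced.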

I have to support gcc 4.8.5, which is mostly C++11 compliant.

Thank you
