Friday, 28 December 2018

comma operator overload: how to trigger an action after the last comma in a statement has been processed

I have a simple class which wraps an array of doubles as follows:

struct VecExpr {
    VecExpr(double *v, size_t n) : x(v), n(n) {}
    double *x;
    size_t n;
};

Based on this class I implement a vector algebra framework that fuses operations. For example, if vx, vy, vz and vr are objects of type VecExpr, I can write:

vr = vx + vy * vz, vr = vr * vr, vr = vr + vx;

and the computation is performed internally as if I had written:

for (size_t i = 0; i < vr.n; ++i) {
   vr[i] = vx[i] + vy[i] * vz[i];   // i-th iteration, 1st assignment
   vr[i] = vr[i] * vr[i];           // i-th iteration, 2nd assignment
   vr[i] = vr[i] + vx[i];           // i-th iteration, 3rd assignment
}

As you can see, operations are also fused across different assignments. This is achieved by creating various expression classes which overload the comma operator. Nothing really happens until the last comma in the statement has been processed, at which point evaluation of the whole loop is triggered.

A minimal demo of this framework is given below.

In more detail, in pseudocode, at compile time:

  • the expression vr = vx + vy * vz triggers construction of an AssignExpr (AE) object: AE(vr, vx + vy * vz)
  • the expression vr = vr * vr triggers construction of another AssignExpr object: AE(vr, vr * vr)
  • the comma separating the two expressions above triggers construction of a MultiExpr (ME) object: ME(AE(vr, vx + vy * vz), AE(vr, vr * vr))
  • the expression vr = vr + vx triggers construction of a third AssignExpr object: AE(vr, vr + vx)
  • the second comma triggers construction of another MultiExpr object: ME(ME(AE(vr, vx + vy * vz), AE(vr, vr * vr)), AE(vr, vr + vx))

At this point, since the statement is complete, evaluation starts.

To trigger evaluation I use the destructors of AssignExpr and MultiExpr. When an object of either class is destroyed, it triggers evaluation if it is the outermost expression.

Is there a better way? How can I determine whether an object is the outermost one or not?

In the demo below, I add a boolean flag outer to AssignExpr and MultiExpr. When an object is created, the flag is tentatively initialized to true, but if the object is then passed as an argument to the constructor of an enclosing expression, the enclosing expression resets the flag to false.

I kind of feel there should be a better way. In the end, everything is known at compile time. Any suggestions?

#include <functional>
#include <iostream>
#include <cassert>
using namespace std;

// forward declaration
template <typename RHS> struct AssignExpr;
template <typename LHS, typename RHS> struct MultiExpr;

struct Size {
    Size(size_t s) : sz(s) {}
    size_t size() const { return sz; }
    size_t sz;
};

struct VecExpr : Size {
    VecExpr(double *v, size_t n) : Size(n), x(v) {}  // base class is initialized first

    void set(size_t i, double v) const { x[i] = v; }
    double operator[](size_t i) const { return x[i]; }

    template <typename RHS> auto operator=(const RHS& rhs) const { return AssignExpr<RHS>(*this, rhs); }

    double *x;
};

template <template <typename> class Op, typename Arg1, typename Arg2>
struct OperExpr : Size
{
    OperExpr(const Arg1& a1, const Arg2& a2) : Size(a1.size()), arg1(a1), arg2(a2) { assert(a1.size() == a2.size()); }
    double operator[](size_t i) const { return Op<double>{}(arg1[i], arg2[i]); }  // Op<double>, so operands are not truncated to int

    Arg1 arg1;
    Arg2 arg2;
};


template <typename RHS>
struct AssignExpr : Size
{
    AssignExpr(const VecExpr l, const RHS r)
        : Size(l.size()), lhs(l), rhs(r), outer(true)  // tentatively set to true
    { assert(l.size() == r.size()); }
    AssignExpr(const AssignExpr<RHS>& expr)
        : Size(expr.size()), lhs(expr.lhs), rhs(expr.rhs), outer(false) {}

    void set(size_t i) const { lhs.set(i, rhs[i]); }

    template <class R>
    auto operator,(const R& r) { return MultiExpr<AssignExpr<RHS>,R>(*this, r); }

    ~AssignExpr() {
       if(outer)  // if outer expression, triggers evaluation
          for (size_t i = 0; i < size(); ++i)
             set(i);
    }

    VecExpr lhs;
    RHS rhs;
    mutable bool outer;
};

template <typename LHS, typename RHS>
struct MultiExpr : Size
{
    MultiExpr(const LHS& l, const RHS& r)
        : Size(l.size()), lhs(l), rhs(r), outer(true)  // tentatively set to true
    { l.outer = r.outer = false;  assert(l.size() == r.size()); } // reset flag for arguments
    MultiExpr(const MultiExpr<LHS, RHS>& expr)
        : Size(expr.size()), lhs(expr.lhs), rhs(expr.rhs), outer(false) {}

    void set(size_t i) const { lhs.set(i); rhs.set(i); }  // evaluates left to right, i.e. in statement order

    template <class R>
    auto operator,(const R& r) { return MultiExpr<MultiExpr<LHS, RHS>, R>(*this, r); }

    ~MultiExpr() {
       if(outer)  // if outer expression, triggers evaluation
          for (size_t i = 0; i < size(); ++i)
             set(i);
    }


    LHS lhs;
    RHS rhs;
    mutable bool outer;
};

template <typename Arg1, typename Arg2>
auto operator*(const Arg1& arg1, const Arg2& arg2)
{
    return OperExpr<multiplies, Arg1, Arg2>(arg1, arg2);
}

template <typename Arg1, typename Arg2>
auto operator+(const Arg1& arg1, const Arg2& arg2)
{
    return OperExpr<plus, Arg1, Arg2>(arg1, arg2);
}

void demo(size_t n, double *x, double *y, double *z, double *r)
{
    VecExpr vx(x, n), vy(y, n), vz(z, n), vr(r, n);
    vr = vx + vy * vz, vr = vr * vr, vr = vr + vx;
}

int main()
{
    double x[] = {2, 3, 4}, y[] = {3, 4, 5}, z[] = {4, 5, 6}, r[3];
    demo(3, x, y, z, r);
    for (auto d : r)
        cout << d << "\n";
    return 0;
}
