I have a simple class which wraps an array of doubles as follows:
struct VecExpr {
    VecExpr(double *v, size_t n) : x(v), n(n) {}
    double *x;
    size_t n;
};
Based on this class, I implement a vector algebra framework that fuses operations. For example, let vx, vy, vz and vr be objects of type VecExpr; then I can write:
vr = vx + vy * vz, vr = vr * vr, vr = vr + vx;
and the computations are performed internally as if I had written:
for (size_t i = 0; i < vr.size(); ++i) {
    vr[i] = vx[i] + vy[i] * vz[i]; // i-th iteration, 1st assignment
    vr[i] = vr[i] * vr[i];         // i-th iteration, 2nd assignment
    vr[i] = vr[i] + vx[i];         // i-th iteration, 3rd assignment
}
As you can see, operations are also fused across different assignments. This is achieved by creating various expression classes that overload the comma operator. Nothing really happens until the last comma in the statement has been processed, at which point evaluation of the whole loop is triggered.
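For reference, the fused loop can also be written by hand on raw arrays. The function name fused_reference below is hypothetical; it is only meant as a baseline to compare the expression-template version against.

```cpp
#include <cassert>
#include <cstddef>

// Hypothetical hand-fused reference: a single pass over the data that applies
// the three assignments in order for each index i.
void fused_reference(std::size_t n, const double *x, const double *y,
                     const double *z, double *r)
{
    assert(x && y && z && r);      // buffers must be non-null (their length n is not checked)
    for (std::size_t i = 0; i < n; ++i) {
        r[i] = x[i] + y[i] * z[i]; // 1st assignment
        r[i] = r[i] * r[i];        // 2nd assignment
        r[i] = r[i] + x[i];        // 3rd assignment
    }
}
```

With the arrays used in main further down (x = {2, 3, 4}, y = {3, 4, 5}, z = {4, 5, 6}), it leaves r = {198, 532, 1160}.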
A minimal demo of this framework is given below.
In more detail, in pseudocode (the nesting is fixed at compile time, while the objects themselves are built at run time):
- the expression vr = vx + vy * vz triggers construction of an AssignExpr (AE) object: AE(vr, vx + vy * vz)
- the expression vr = vr * vr triggers construction of another AssignExpr object: AE(vr, vr * vr)
- the comma separating the two expressions above triggers construction of a MultiExpr (ME) object: ME(AE(vr, vx + vy * vz), AE(vr, vr * vr))
- the expression vr = vr + vx triggers construction of a third AE object: AE(vr, vr + vx)
- the second comma triggers construction of another MultiExpr object: ME(ME(AE(vr, vx + vy * vz), AE(vr, vr * vr)), AE(vr, vr + vx))
At this point, since the statement is complete, evaluation starts.
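The nesting in the last step is carried entirely by the type of the full expression. The sketch below uses hypothetical, data-free stand-ins Assign<N> and Multi<L, R> (not the demo's AssignExpr/MultiExpr) to check at compile time that two left-associative commas produce the shape ME(ME(AE, AE), AE).

```cpp
#include <cassert>
#include <type_traits>

// Hypothetical stand-ins: Assign<N> plays the N-th AssignExpr,
// Multi<L, R> plays MultiExpr; they only record structure.
template <int N> struct Assign {};
template <class L, class R> struct Multi { L lhs; R rhs; };

// AE , X  ->  ME(AE, X)
template <int N, class R>
Multi<Assign<N>, R> operator,(Assign<N> l, R r) { return {l, r}; }

// ME , X  ->  ME(ME, X)
template <class L1, class R1, class R>
Multi<Multi<L1, R1>, R> operator,(Multi<L1, R1> l, R r) { return {l, r}; }

// The comma is left-associative, so a, b, c groups as (a, b), c.
using Stmt = decltype((Assign<1>{}, Assign<2>{}, Assign<3>{}));
static_assert(std::is_same<Stmt, Multi<Multi<Assign<1>, Assign<2>>, Assign<3>>>::value,
              "the statement type mirrors ME(ME(AE, AE), AE)");
```

Since the whole tree is a single type, the compiler can inline the evaluation loop over it; what the type cannot tell you is which comma is the last one in the statement.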
To trigger evaluation, I use the destructors of AssignExpr and MultiExpr: when an object of either class is destroyed, it triggers evaluation if it is the outermost expression.
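The destructor-based trigger leans on two C++17 rules: the operands of an overloaded comma are evaluated left to right, and the temporaries of a full-expression are destroyed at its end in reverse order of construction. The sketch below, with a hypothetical logging type Node and a global trace vector (neither is part of the demo), makes this visible: the outermost object is constructed last and destroyed first, while all inner temporaries are still alive.

```cpp
#include <cassert>
#include <string>
#include <utility>
#include <vector>

// Global record of lifetime events (illustration only).
std::vector<std::string> trace;

// Hypothetical expression node that logs its construction and destruction.
struct Node {
    std::string name;
    explicit Node(std::string n) : name(std::move(n)) { trace.push_back("ctor " + name); }
    Node(const Node&) = delete;  // no copies, so every object shows up in the trace
    ~Node() { trace.push_back("dtor " + name); }
};

// Overloaded comma: builds a new node whose name records the nesting.
Node operator,(const Node& a, const Node& b) {
    return Node("(" + a.name + "," + b.name + ")");
}

// One statement, one full-expression: groups as (a, b), c.
void statement() {
    Node("a"), Node("b"), Node("c");
}
```

After statement() runs, trace holds the five constructions a, b, (a,b), c, ((a,b),c) followed by the destructions in the opposite order. So the outermost destructor can safely read every inner temporary; the open question is only how each destructor knows whether it is the outermost one. This needs C++17 (guaranteed copy elision for the deleted copy constructor, and the sequencing rule for overloaded operators).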
Is there a better way? How can I determine whether an object is the outermost one?
In the demo below, I add a boolean flag outer to AssignExpr and MultiExpr. When an object is created, the flag is tentatively initialized to true; if the object is then passed as an argument to the constructor of an enclosing expression, that constructor resets the flag to false.
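A variant of the flag approach just described (a sketch, not the demo's code): keep the expression tree itself inert and put the flag in a single wrapper, here called Statement, whose comma operator takes its operands by rvalue reference and disarms them. Scale, Shift, Seq, scale and shift are hypothetical stand-ins for the demo's AssignExpr, MultiExpr and VecExpr::operator=. It is still a runtime flag, but no mutable member has to be reset through a const reference, and because the move constructor hands the flag over, a moved-from temporary never fires, even where copy elision is not guaranteed (before C++17).

```cpp
#include <cassert>
#include <cstddef>

// Hypothetical inert statement pieces: plain data, no flag, no destructor.
struct Scale {                        // dst[i] *= k
    double *dst; double k;
    void set(std::size_t i) const { dst[i] *= k; }
};
struct Shift {                        // dst[i] += k
    double *dst; double k;
    void set(std::size_t i) const { dst[i] += k; }
};
template <class L, class R>
struct Seq {                          // run L, then R, at the same index
    L lhs; R rhs;
    void set(std::size_t i) const { lhs.set(i); rhs.set(i); }
};

// The only type that carries the "outermost" flag.
template <class E>
struct Statement {
    E expr;
    std::size_t n;
    bool armed = true;                // tentatively the outermost statement
    Statement(E e, std::size_t len) : expr(e), n(len) {}
    // Moving hands the flag over, so the moved-from object is disarmed.
    Statement(Statement&& o) noexcept : expr(o.expr), n(o.n), armed(o.armed) { o.armed = false; }
    Statement(const Statement&) = delete;
    ~Statement() {
        if (armed)                    // only the outermost statement runs the fused loop
            for (std::size_t i = 0; i < n; ++i)
                expr.set(i);
    }
};

// Combining takes both operands by rvalue reference and disarms them.
template <class L, class R>
Statement<Seq<L, R>> operator,(Statement<L>&& l, Statement<R>&& r) {
    assert(l.n == r.n);
    l.armed = r.armed = false;
    return Statement<Seq<L, R>>(Seq<L, R>{l.expr, r.expr}, l.n);
}

// Hypothetical factories standing in for the demo's VecExpr::operator=.
inline Statement<Scale> scale(double *dst, std::size_t n, double k) { return {Scale{dst, k}, n}; }
inline Statement<Shift> shift(double *dst, std::size_t n, double k) { return {Shift{dst, k}, n}; }
```

For example, with r = {1, 2, 3}, the statement scale(r, 3, 2.0), shift(r, 3, 1.0), scale(r, 3, 10.0); runs a single loop and leaves r = {30, 50, 70}.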
I kind of feel there should be a better way. After all, everything is known at compile time. Any suggestions?
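If a small change of syntax is acceptable, the end of the statement can be made explicit so that no flag is needed at all: pass the assignments as arguments to a variadic function, which sees all of them at once and runs one fused loop using a C++17 fold expression. The function name fuse and the types Scale and Shift are hypothetical; in the demo, the arguments would be the AssignExpr objects produced by operator=, whose destructors would then no longer evaluate anything.

```cpp
#include <cassert>
#include <cstddef>

// Hypothetical inert assignments: plain data with a per-index set().
struct Scale {                       // dst[i] *= k
    double *dst; double k;
    void set(std::size_t i) const { dst[i] *= k; }
};
struct Shift {                       // dst[i] += k
    double *dst; double k;
    void set(std::size_t i) const { dst[i] += k; }
};

// Runs every statement at index i, in argument order, for i = 0..n-1.
// The set of statements is fixed at compile time by the parameter pack.
template <class... Stmts>
void fuse(std::size_t n, const Stmts&... stmts) {
    for (std::size_t i = 0; i < n; ++i)
        (stmts.set(i), ...);         // fold over the built-in comma: evaluated left to right
}
```

Applied to the demo, the call site would look something like fuse(vr.size(), vr = vx + vy * vz, vr = vr * vr, vr = vr + vx); the trade-off is the explicit function call in place of the bare comma-separated statement.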
#include <functional>
#include <iostream>
#include <cassert>
using namespace std;
// forward declaration
template <typename RHS> struct AssignExpr;
template <typename LHS, typename RHS> struct MultiExpr;
struct Size {
    Size(size_t s) : sz(s) {}
    size_t size() const { return sz; }
    size_t sz;
};
struct VecExpr : Size {
    VecExpr(double *v, size_t n) : Size(n), x(v) {} // base listed first, matching initialization order
    void set(size_t i, double v) const { x[i] = v; }
    double operator[](size_t i) const { return x[i]; }
    template <typename RHS> auto operator=(const RHS& rhs) const { return AssignExpr<RHS>(*this, rhs); }
    double *x;
};
template <template <typename> class Op, typename Arg1, typename Arg2>
struct OperExpr : Size
{
    OperExpr(const Arg1& a1, const Arg2& a2) : Size(a1.size()), arg1(a1), arg2(a2) { assert(a1.size() == a2.size()); }
    // instantiate the functor with double so the operands are not converted to int
    double operator[](size_t i) const { return Op<double>{}(arg1[i], arg2[i]); }
    Arg1 arg1;
    Arg2 arg2;
};
template <typename RHS>
struct AssignExpr : Size
{
    AssignExpr(const VecExpr l, const RHS r)
        : Size(l.size()), lhs(l), rhs(r), outer(true) // tentatively set to true
    { assert(l.size() == r.size()); }
    AssignExpr(const AssignExpr<RHS>& expr)
        : Size(expr.size()), lhs(expr.lhs), rhs(expr.rhs), outer(false) {}
    void set(size_t i) const { lhs.set(i, rhs[i]); }
    template <class R>
    auto operator,(const R& r) { return MultiExpr<AssignExpr<RHS>,R>(*this, r); }
    ~AssignExpr() {
        if(outer) // if outermost expression, trigger evaluation
            for (size_t i = 0; i < size(); ++i)
                set(i);
    }
    VecExpr lhs;
    RHS rhs;
    mutable bool outer;
};
template <typename LHS, typename RHS>
struct MultiExpr : Size
{
    MultiExpr(const LHS& l, const RHS& r)
        : Size(l.size()), lhs(l), rhs(r), outer(true) // tentatively set to true
    { l.outer = r.outer = false; assert(l.size() == r.size()); } // reset flag for arguments
    MultiExpr(const MultiExpr<LHS, RHS>& expr)
        : Size(expr.size()), lhs(expr.lhs), rhs(expr.rhs), outer(false) {}
    void set(size_t i) const { lhs.set(i); rhs.set(i); } // evaluates sub-expressions left to right
    template <class R>
    auto operator,(const R& r) { return MultiExpr<MultiExpr<LHS, RHS>, R>(*this, r); }
    ~MultiExpr() {
        if(outer) // if outermost expression, trigger evaluation
            for (size_t i = 0; i < size(); ++i)
                set(i);
    }
    LHS lhs;
    RHS rhs;
    mutable bool outer;
};
template <typename Arg1, typename Arg2>
auto operator*(const Arg1& arg1, const Arg2& arg2)
{
    return OperExpr<multiplies, Arg1, Arg2>(arg1, arg2);
}
template <typename Arg1, typename Arg2>
auto operator+(const Arg1& arg1, const Arg2& arg2)
{
    return OperExpr<plus, Arg1, Arg2>(arg1, arg2);
}
void demo(size_t n, double *x, double *y, double *z, double *r)
{
    VecExpr vx(x, n), vy(y, n), vz(z, n), vr(r, n);
    vr = vx + vy * vz, vr = vr * vr, vr = vr + vx;
}
int main()
{
    double x[] = {2, 3, 4}, y[] = {3, 4, 5}, z[] = {4, 5, 6}, r[3];
    demo(3, x, y, z, r);
    for (auto d : r)
        cout << d << "\n"; // prints 198, 532 and 1160, one per line
    return 0;
}