samedi 2 février 2019

Different behaviour with calling functions inside and outside objects

Some very weird behaviour is happening with my C++ code. I have a function backward() in my MultiplyOperation class that I call from another object FloatTensor class which has an instance of this MultiplyOperation class.

The result of the output is different if I call three.backOperation->backward(1); and three.backOperation->backward(1); whereas it should be same. Please help.

This is my code:

#include<iostream>
using namespace std; 

class FloatTensor;

class MultiplyOperation{
     public:
        FloatTensor *t1, *t2; 
        float grad = 10;

        MultiplyOperation(FloatTensor* t1, FloatTensor* t2);

        FloatTensor compute();

        void backward(float gradient);
};

class FloatTensor {
    public:
    float val; 
    float grad; 
    MultiplyOperation* backOperation = NULL, *frontOperation = NULL;

    FloatTensor() {
        // default
    }

    FloatTensor(float value) {
        this->val = value;
        this->backOperation = NULL;
    }

    FloatTensor(float value, MultiplyOperation* backOp) {
        this->val = value;
        this->backOperation = backOp;
    }

    void backward(float gradient) {
        this->backOperation->backward(gradient);
    }


    FloatTensor operator * (FloatTensor two) { 
        MultiplyOperation ope(this,&two);
        this->frontOperation = &ope;
        return this->frontOperation->compute();
    }
};


MultiplyOperation::MultiplyOperation(FloatTensor* te1, FloatTensor* te2) {
    this->t1 = te1;
    this->t2 = te2;
}

FloatTensor MultiplyOperation::compute() {
    return FloatTensor(this->t1->val*this->t2->val, this);
}

void MultiplyOperation::backward(float gradient) {

    cout<<this->t2->val<<endl;
}


int main() {

    FloatTensor one(2);
    FloatTensor two(4);

    FloatTensor three = one*two;

    three.backOperation->backward(1); // should be same as output of next line and is 4. (which is correct)
    three.backward(1); // should be same as output of above line but is garbage value -4.12131 

}

Aucun commentaire:

Enregistrer un commentaire