I'm seeing some very strange behaviour in my C++ code. My MultiplyOperation class has a backward() function, which I call through a FloatTensor object that holds a pointer to a MultiplyOperation instance.
The output differs depending on whether I call
three.backOperation->backward(1);
or
three.backward(1);
even though the two should be identical, since FloatTensor::backward() simply forwards to backOperation->backward(). Please help.
This is my code:
#include <iostream>
#include <memory>
using namespace std;
class FloatTensor;

/// A recorded multiplication t1 * t2 in the computation graph.
/// Both operand pointers are non-owning: the operands must outlive this
/// operation for compute()/backward() to be safe.
class MultiplyOperation {
public:
    FloatTensor* t1;  // left operand (not owned)
    FloatTensor* t2;  // right operand (not owned)

    // NOTE(review): initialised to 10 but not read by any code shown here —
    // confirm whether this is meant to hold an accumulated gradient.
    float grad{10};

    /// Stores the two operand pointers.
    MultiplyOperation(FloatTensor* t1, FloatTensor* t2);

    /// Returns a tensor holding t1->val * t2->val whose backOperation is
    /// this operation.
    FloatTensor compute();

    /// Backward pass entry point for this operation.
    void backward(float gradient);
};
class FloatTensor {
public:
float val;
float grad;
MultiplyOperation* backOperation = NULL, *frontOperation = NULL;
FloatTensor() {
// default
}
FloatTensor(float value) {
this->val = value;
this->backOperation = NULL;
}
FloatTensor(float value, MultiplyOperation* backOp) {
this->val = value;
this->backOperation = backOp;
}
void backward(float gradient) {
this->backOperation->backward(gradient);
}
FloatTensor operator * (FloatTensor two) {
MultiplyOperation ope(this,&two);
this->frontOperation = &ope;
return this->frontOperation->compute();
}
};
/// Records the two operand pointers; the operation does not take ownership.
MultiplyOperation::MultiplyOperation(FloatTensor* te1, FloatTensor* te2)
    : t1(te1), t2(te2) {}
/// Evaluates t1->val * t2->val and returns it as a new tensor whose
/// backOperation points back at this operation, linking it into the graph.
FloatTensor MultiplyOperation::compute() {
    const float product = t1->val * t2->val;
    return FloatTensor(product, this);
}
void MultiplyOperation::backward(float gradient) {
cout<<this->t2->val<<endl;
}
int main() {
    FloatTensor one(2);
    FloatTensor two(4);
    FloatTensor three = one * two;  // three.backOperation is the multiply node

    // Both calls reach the same MultiplyOperation::backward, which prints
    // t2->val, so both are expected to print 4 (the value of `two`).
    three.backOperation->backward(1);  // direct call on the operation
    three.backward(1);                 // forwarded through FloatTensor::backward
    return 0;
}
Aucun commentaire:
Enregistrer un commentaire