I have a TensorFlow model written in Python and a C++ application that acquires the training data and passes it to TensorFlow through the embedded Python API.
The application has a GUI, so it is naturally multithreaded.
The problem is that, although the application runs without errors, the model does not seem to train (at least not the way I expect): the trained variable always ends up as some random number close to its initial value. When I dumped the data and fed it to the Python class directly from Python, it worked perfectly.
I assume I am doing something wrong with threading. Am I?
The Python class is defined like this (for simplicity, assume there is only one variable):
```python
import tensorflow as tf  # TensorFlow 1.x API (tf.Session)


class TrainingClass:
    def __init__(self, params):
        # assign params
        self.variable = tf.Variable(42, dtype=tf.float64)
        # build graph, loss etc. (defines self.param_placeholders,
        # self.loss and self.list_of_optimizers)...
        self.session = tf.Session()
        self.session.run(tf.global_variables_initializer())

    def train(self, train_params):
        # Run one optimization step and return only the loss value
        return self.session.run([self.list_of_optimizers, self.loss],
                                feed_dict={self.param_placeholders: train_params})[-1]

    def get_trained_variable(self):
        return self.session.run(self.variable)
```
On the C++ side, things look like this:
```cpp
#include <Python.h>
#include <thread>

class MyCppClass {
    std::thread *m_training_thread;
    bool thread_is_canceled();

    void on_start_button_pressed() {
        // Any call into the Python C API needs the GIL
        PyGILState_STATE state = PyGILState_Ensure();
        PyObject *python_class_instance =
            PyObject_CallFunctionObjArgs(python_class_type, /* params..., */ NULL);
        PyGILState_Release(state);
        m_training_thread = new std::thread(&MyCppClass::cpp_train_function, this, python_class_instance);
    }

    void cpp_train_function(PyObject *python_class_instance) {
        while (!thread_is_canceled()) {
            // Acquire the GIL
            PyGILState_STATE state = PyGILState_Ensure();

            PyObject *py_list = PyList_New(42);
            /* Long and boring code for assigning
             * the items of the list
             */

            // Pass the py_list to TensorFlow
            PyObject *py_loss = PyObject_CallMethodObjArgs(
                python_class_instance, train_function_name, py_list, NULL);
            if (py_loss) {
                // Display current loss value to the user
                Py_DECREF(py_loss);
            } else {
                PyErr_Print();  // otherwise a Python exception goes unnoticed
            }
            Py_XDECREF(py_list);

            // Release the GIL
            PyGILState_Release(state);
        }
    }

    void show_trained_parameters(PyObject *python_class_instance) {
        // Acquire the GIL
        PyGILState_STATE state = PyGILState_Ensure();
        PyObject *trained_vars = PyObject_CallMethodObjArgs(
            python_class_instance, get_trained_variable_py_str, NULL);
        // Show the trained variable
        Py_XDECREF(trained_vars);
        // Release the GIL
        PyGILState_Release(state);
    }
};
```
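Since `TrainingClass` trains correctly when the dumped data is fed to it from plain Python, one way to test the threading hypothesis without the C API is to mirror `cpp_train_function` in Python and call `train()` from a worker thread. The sketch below is a minimal version of that idea; `run_training_in_worker` and its arguments are hypothetical names, and it only assumes that the trainer object has `train(batch)` and `get_trained_variable()` methods like `TrainingClass`.

```python
import threading


def run_training_in_worker(trainer, batches):
    """Call trainer.train() for every batch on a worker thread (like cpp_train_function).

    Returns (value_before, value_after, losses) so the caller can see whether
    the trained variable actually moved.
    """
    losses = []
    errors = []
    value_before = trainer.get_trained_variable()

    def worker():
        try:
            for batch in batches:
                losses.append(trainer.train(batch))
        except Exception as exc:  # keep errors from silently dying with the thread
            errors.append(exc)

    thread = threading.Thread(target=worker)
    thread.start()
    thread.join()  # wait for the worker before reading the result
    if errors:
        raise errors[0]
    return value_before, trainer.get_trained_variable(), losses
```

With the real class this would be called as `run_training_in_worker(TrainingClass(params), dumped_batches)`, where `dumped_batches` is a placeholder for the data dumped from the C++ side. If the variable moves away from 42 here but not when driven from C++, that would point at the embedding code rather than at calling TensorFlow from a second thread.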
I verified that `py_list` is assigned correctly by printing it from Python.
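Printing only shows that the list looks right. A stricter check, assuming the batches dumped for the pure-Python run are still available and that each batch is a flat list of numbers (as built with `PyList_New(42)`), is to compare what arrives in `train()` element by element against that reference. `first_mismatch` is a hypothetical helper, not part of TensorFlow or the C API.

```python
import math


def first_mismatch(received, reference, rel_tol=1e-9, abs_tol=1e-12):
    """Return None if two flat numeric batches agree, else a description of the first difference."""
    received = list(received)
    reference = list(reference)
    if len(received) != len(reference):
        return "length %d != %d" % (len(received), len(reference))
    for i, (got, want) in enumerate(zip(received, reference)):
        # A non-numeric item usually means a slot in the C++-built list was filled wrongly
        if not isinstance(got, (int, float)):
            return "item %d has type %s" % (i, type(got).__name__)
        if not math.isclose(got, want, rel_tol=rel_tol, abs_tol=abs_tol):
            return "item %d: %r != %r" % (i, got, want)
    return None
```

Calling it at the top of `train()` and logging the result would show whether the C++ side delivers the same numbers, in the same order, as the run that trained correctly.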
OS: Ubuntu 18.04, C++11, Python 3.6