vendredi 20 avril 2018

Multithreaded Binary Tree C++

Suppose you are building a binary tree over some simple data. I would like to know how to use C++11 multithreading to build the tree.

I want to hand the root's left branch to one thread and its right branch to another thread. The two branches do not interact while they are being built, so no synchronization between them should be needed. After both threads finish, I want to attach the two branches to the root myself.

I don't want the threads to be visible to the user or in main(). Here is what I have so far (single-threaded):

#pragma once

#include <cmath>
#include <future>
#include <iostream>
#include <memory>
#include <random>
#include <string>
#include <thread>
#include <utility>
#include <vector>

using namespace std;

/// Binary partitioning tree over a shared vector of scalar points.
///
/// Nodes never copy point values: every node holds a reference to the caller's
/// vector plus the indices of the points that belong to it. The caller must keep
/// `pts` alive, and should not modify it, for as long as the tree is used.
class Tree {
private:
    std::vector<float>& pts;   // shared, non-owning reference to the input points
    int leaves = 0;            // point count of this node at the moment it was split (0 if never split)
    int max_depth = 0;         // depth limit for splitting, propagated from the root
    int depth = 0;             // depth of this node (root = 0)
    std::vector<int> indices;  // indices into pts of the points held by this node
    std::unique_ptr<Tree> L;   // child holding points above this node's mean (may be null)
    std::unique_ptr<Tree> R;   // child holding points at or below this node's mean (may be null)

    void build_bin();
    void init();
public:
    explicit Tree(std::vector<float>& pts) : pts(pts) {}

    // Copying is disabled: subtrees are uniquely owned. User-declaring the copy
    // operations suppresses the implicit move constructor, so it is defaulted
    // explicitly to obtain the intended move-only semantics. Move assignment stays
    // unavailable because the reference member `pts` cannot be reseated.
    Tree(const Tree& other) = delete;
    Tree& operator=(const Tree& rhs) = delete;
    Tree(Tree&&) noexcept = default;

    /// Builds the whole tree over every point in `pts`.
    void build();
};

void Tree::build() {
    this->init();
    cout << "init done" << endl;
    this->build_bin();
    cout << "build_bin done" << endl;
}

/// Prepares this node as the root for build_bin(): it receives every point index,
/// the depth limit max_depth = 1 + floor(log2(pts.size())), and depth 0.
///
/// The limit is computed with integer shifts, so it is exact for every size. The
/// previous floating-point form `1 + log(n)/log(2)` converted -infinity to int
/// (undefined behaviour) when `pts` was empty. Its rounding is also not guaranteed
/// to land exactly on the integer for powers of two.
void Tree::init() {
    // Reset state left over from an earlier build() so a rebuild neither duplicates
    // indices nor keeps stale subtrees.
    this->indices.clear();
    this->indices.reserve(this->pts.size());
    for (size_t i = 0; i < this->pts.size(); ++i) {
        this->indices.push_back(static_cast<int>(i));
    }
    this->L.reset();
    this->R.reset();

    // floor(log2(n)) for n >= 1; 0 for n == 0 (an empty tree never splits anyway).
    int floor_log2 = 0;
    for (size_t n = this->pts.size(); n > 1; n >>= 1) {
        ++floor_log2;
    }
    this->max_depth = 1 + floor_log2;
    this->depth = 0;
}

void Tree::build_bin() {
    int nb = this->indices.size();
    //cout << "At depth " << this->depth << ", nb of pts is " << nb << endl;

    if (nb <= 1) {
        return;
    }
    if (this->depth == this->max_depth) {
        return;
    }

    this->leaves = nb;

    //Compute center
    float mean;
    mean = pts.at(this->indices.at(0));
    for (size_t i = 1; i < nb; ++i) {
        mean = mean + (pts.at(this->indices.at(i)) - mean) / (i + 1);
    }

    vector<int> L_indices;
    vector<int> R_indices;

    for (size_t i = 0; i < this->indices.size(); ++i) {
        float pt = pts.at(indices.at(i));
        if (pt > mean) {
            L_indices.push_back(indices.at(i));
        }
        else
        {
            R_indices.push_back(indices.at(i));
        }
    }
    indices.clear(); //free memory

    if (L_indices.size() > 0) {
        this->L = unique_ptr<Tree>(new Tree(this->pts));
        this->L->indices = L_indices;
        L_indices.clear();
        this->L->max_depth = this->max_depth;
        this->L->depth = this->depth + 1;
    }
    else {
        this->L = nullptr;
        L_indices.clear();
    }

    if (R_indices.size() > 0) {
        this->R = unique_ptr<Tree>(new Tree(this->pts));
        this->R->indices = R_indices;
        R_indices.clear();
        this->R->max_depth = this->max_depth;
        this->R->depth = this->depth + 1;
    }
    else {
        this->R = nullptr;
        R_indices.clear();
    }

    if (this->L != nullptr) {
        this->L->build_bin();
    }
    if (this->R != nullptr) {
        this->R->build_bin();
    }
    return;
}

/// Demo driver: fills a vector with `nrolls` samples from a normal distribution
/// (mean 5, standard deviation 10), builds a Tree over it, then waits for one
/// character of input (e.g. Enter) before exiting.
int main(int argc, char *argv[])
{
    const int nrolls = 107090; // number of samples to generate
    // A default-constructed engine uses its default seed, so for a given standard
    // library every run generates the same data.
    std::default_random_engine generator;
    std::normal_distribution<float> distribution(5.0, 10.0);

    vector<float> data;
    data.reserve(nrolls);
    for (int i = 0; i < nrolls; ++i) {
        data.push_back(distribution(generator));
    }

    cout << "Number of elements in obj " << data.size() << " ." << endl;
    Tree tree(data);
    cout << "Create Tree(data)." << endl;
    tree.build();
    cout << "tree.build()." << endl;
    cin.ignore();
    return 0;
}

Aucun commentaire:

Enregistrer un commentaire