lundi 24 octobre 2022

Cannot assign contents of list attribute in Pybind11 defined class

I have a sparse matrix implementation in C++, and I used pybind11 to expose it to Python. Here is the problem:

>>> D1 = phc.SparseMatrix(3, [[0],[1],[2]])
>>> D1.cData
[[0], [1], [2]]
>>> D1.cData[1] = [1,2]
>>> D1.cData
[[0], [1], [2]] #Should be [[0], [1,2], [2]]

In Python, I cannot change the contents of the SparseMatrix.cData attribute with item assignment, although I can replace the entire list with D1.cData = [[1],[2],[3]]. This behavior is bewildering to me: D1.cData is just a list, so I would expect the code above to work.

I suspect it has something to do with my pybind11 code, since this behavior does not occur with custom classes defined in Python. But I have no idea what is wrong (I am a novice programmer). Here is the relevant source code:

Python Bindings

#include <pybind11/pybind11.h>
#include <pybind11/stl.h>

#include <iostream>
#include <sstream>
#include <string>
#include <utility>

#include <SparseMatrix.h>

namespace py = pybind11;

namespace phc = ph_computation;
using SparseMatrix = phc::SparseMatrix;
using Column = phc::Column;
using CData = phc::CData;

PYBIND11_MODULE(ph_computations, m)
{
    m.doc() = "ph_computations python bindings";

    using namespace pybind11::literals;

    m.def("add_cols", &phc::add_cols);//begin SparseMatrix.h

    py::class_<SparseMatrix>(m, "SparseMatrix")
            .def(py::init<size_t, CData>())
            .def(py::init<std::string>())
            .def_readwrite("n_rows", &SparseMatrix::n_rows)
            .def_readwrite("n_cols", &SparseMatrix::n_cols)
            // pybind11/stl.h converts CData (std::vector<std::vector<int>>) by
            // value: every read of `cData` builds a brand-new Python list.
            // Item assignment such as `M.cData[1] = [1, 2]` therefore only
            // modifies that temporary list and never reaches the C++ object.
            // Whole-list assignment (`M.cData = [...]`) goes through the setter
            // below; use set_col() to replace a single column in place.
            .def_property("cData",
                    [](const SparseMatrix &self) { return self.cData; },
                    [](SparseMatrix &self, CData data) {
                        // Keep n_cols consistent with the new column data:
                        // operator+, print, save, transpose and
                        // matching_pivots all iterate up to n_cols.
                        self.n_cols = data.size();
                        self.cData = std::move(data);
                    })
            .def("set_col",
                    [](SparseMatrix &self, size_t j, Column col) {
                        // .at() throws std::out_of_range, which pybind11
                        // translates into a Python IndexError.
                        self.cData.at(j) = std::move(col);
                    },
                    "j"_a, "col"_a,
                    "Replace column j of the matrix in place.")
            // Python looks up __ne__ for `!=`. The misspelled __neq__ is kept
            // only so that existing explicit calls keep working.
            .def("__ne__", &SparseMatrix::operator!=)
            .def("__neq__", &SparseMatrix::operator!=)
            .def("__eq__", &SparseMatrix::operator==)
            .def("__add__", &SparseMatrix::operator+)
            .def("__mul__", &SparseMatrix::operator*)
            .def("transpose", &SparseMatrix::transpose)
            .def("__str__", [](const SparseMatrix &self) {
                // SparseMatrix::print() writes to std::cout and returns void,
                // but Python requires __str__ to return a str. Temporarily
                // redirect std::cout into a buffer and return what print() wrote.
                std::ostringstream buffer;
                std::streambuf *const saved = std::cout.rdbuf(buffer.rdbuf());
                try {
                    self.print();
                } catch (...) {
                    std::cout.rdbuf(saved);
                    throw;
                }
                std::cout.rdbuf(saved);
                return buffer.str();
            })
            .def("save", &SparseMatrix::save)
            ;

    m.def("identity", &phc::make_identity);
    m.def("matching_pivots", &phc::matching_pivots);//end SparseMatrix.h
}

SparseMatrix.h

#pragma once

#include <iostream>
#include <sstream>
#include <fstream>
#include <string>
#include <iterator>
#include <algorithm>
#include <vector>
#include <stdexcept>

namespace ph_computation{

using Int = int;

using Column = std::vector<Int>;//a Column is represented by a vector of (row) indices
using CData = std::vector<Column>;//a matrix is represented by a vector of Columns

//Add columns in Z2: returns the symmetric difference of the two index lists,
//so an index present in both columns cancels out. Expects both inputs to be
//sorted in ascending order (the implementation walks them like a merge).
Column add_cols(const Column& c1, const Column& c2);

//Binary (Z2) matrix stored column by column: cData[j] holds the row indices
//of the nonzero entries of column j.
struct SparseMatrix
{
    size_t n_rows{0};
    size_t n_cols{0};//set to cData.size() by the constructors and transpose()
    CData cData;

    SparseMatrix()=default;

    //Takes ownership of cData_; n_cols is derived from its size.
    SparseMatrix(size_t n_rows_, CData cData_):
    n_rows(n_rows_), n_cols(cData_.size())
    {
        //cData_ is a by-value sink, so take over its storage instead of
        //deep-copying every column. This is safe because n_cols was already
        //initialized from cData_ in the initializer list above.
        cData.swap(cData_);
    }

    //Loads a matrix from a text file in the format written by save():
    //n_rows on the first line, n_cols on the second, then one line per column
    //"<nnz> <row_0> ... <row_nnz-1>". Throws std::runtime_error if the file
    //cannot be opened.
    SparseMatrix(std::string path);

    //True when n_rows or the column data differ (n_cols is not compared on its own).
    bool operator!=(const SparseMatrix &other) const;

    bool operator==(const SparseMatrix &other) const;

    //Column-wise Z2 sum; throws std::invalid_argument if the shapes differ.
    SparseMatrix operator+(const SparseMatrix &other) const;

    //Z2 matrix product; throws std::invalid_argument unless n_cols == other.n_rows.
    SparseMatrix operator*(const SparseMatrix &other) const;

    //Transposes in place; n_rows and n_cols are exchanged.
    void transpose();

    //Writes the matrix densely to std::cout as rows of " 0"/" 1" entries.
    void print() const;

    //Writes the matrix to path in the format read by SparseMatrix(std::string);
    //throws std::runtime_error if the file cannot be opened.
    void save(std::string path);
};

//n_cols_ x n_cols_ identity matrix.
SparseMatrix make_identity(size_t n_cols_);

//True when, column by column, a and b are either both empty or both end with
//the same row index (the pivot, assuming sorted columns). Throws
//std::invalid_argument if the shapes differ.
bool matching_pivots(const SparseMatrix& a, const SparseMatrix& b);

}

SparseMatrix.cpp (you probably don't need this)

#include <SparseMatrix.h>

namespace ph_computation {

// Z2 column addition. Both columns are walked together like a merge (so they
// are expected to be sorted ascending): an index found in only one column is
// kept, an index found in both cancels out because 1 + 1 = 0 in Z2.
Column add_cols(const Column& c1, const Column& c2){
    Column sum;
    auto a = c1.cbegin();
    auto b = c2.cbegin();
    const auto a_end = c1.cend();
    const auto b_end = c2.cend();

    while (a != a_end && b != b_end){
        if (*a == *b){
            // shared index: cancels in Z2
            ++a;
            ++b;
        }
        else if (*a < *b){
            sum.push_back(*a);
            ++a;
        }
        else {
            sum.push_back(*b);
            ++b;
        }
    }

    // At most one of the two ranges still has elements left; appending both
    // tails copies whichever one is non-empty.
    sum.insert(sum.end(), a, a_end);
    sum.insert(sum.end(), b, b_end);
    return sum;
}

// Square identity matrix: column j has a single nonzero, in row j.
SparseMatrix make_identity(size_t n_cols_){
    CData columns;
    columns.reserve(n_cols_);
    for (size_t diag = 0; diag < n_cols_; ++diag){
        columns.push_back(Column{static_cast<Int>(diag)});
    }
    return SparseMatrix(n_cols_, columns);
}

// Loads a matrix from a text file in the format written by save():
//   line 1: number of rows
//   line 2: number of columns
//   then one line per column: "<nnz> <row_0> ... <row_nnz-1>"
// Throws std::runtime_error if the file cannot be opened or its column data is
// malformed. std::stoi's std::invalid_argument / std::out_of_range propagate for
// a non-numeric header line.
SparseMatrix::SparseMatrix(std::string path){
    std::ifstream f_in(path);
    if (!f_in.is_open()){
        throw std::runtime_error("File did not open.");
    }

    // Reads one header line. Negative values are rejected because they would
    // wrap around to a huge number when stored in a size_t.
    auto read_dimension = [&f_in]() -> size_t {
        std::string header_line;
        std::getline(f_in, header_line);
        const int value = std::stoi(header_line);
        if (value < 0){
            throw std::runtime_error("Negative matrix dimension in file header.");
        }
        return static_cast<size_t>(value);
    };
    n_rows = read_dimension();
    n_cols = read_dimension();
    cData.assign(n_cols, Column{});

    std::string line;
    size_t j = 0;   // index of the column described by the current line
    while (std::getline(f_in, line)){
        std::istringstream line_str(line);
        int nnz;
        while (line_str >> nnz){
            if (nnz < 0){
                throw std::runtime_error("Negative column size in file.");
            }
            // Without this check, surplus lines would write past the end of cData.
            if (j >= cData.size()){
                throw std::runtime_error("File contains more columns than declared.");
            }
            Column& col_j = cData[j];
            col_j.assign(static_cast<size_t>(nnz), Int{});
            for (Int& row : col_j){
                // A failed extraction used to be stored silently as 0.
                if (!(line_str >> row)){
                    throw std::runtime_error("Column data in file is truncated or not numeric.");
                }
            }
        }
        ++j;   // a blank line still counts as an (empty) column
    }
}

// Two matrices differ when their row counts or stored columns differ;
// n_cols is not compared on its own.
bool SparseMatrix::operator!=(const SparseMatrix &other) const{
    const bool same = (n_rows == other.n_rows) && (cData == other.cData);
    return !same;
}

// Equal when row counts and stored columns match (the exact negation of operator!=).
bool SparseMatrix::operator==(const SparseMatrix &other) const{
    return n_rows == other.n_rows && cData == other.cData;
}

// Z2 matrix sum: column j of the result is add_cols of the two j-th columns.
// Throws std::invalid_argument when the shapes differ.
SparseMatrix SparseMatrix::operator+(const SparseMatrix &other) const{
    const bool same_shape = (n_rows == other.n_rows) && (n_cols == other.n_cols);
    if (!same_shape){
        throw std::invalid_argument("Matrices must have same dimension to add.");
    }

    CData sum_columns;
    sum_columns.reserve(n_cols);
    for (size_t col = 0; col < n_cols; ++col){
        sum_columns.push_back(add_cols(cData[col], other.cData[col]));
    }
    return SparseMatrix(n_rows, sum_columns);
}

// Z2 matrix product (*this) * other. Throws std::invalid_argument unless
// n_cols == other.n_rows.
SparseMatrix SparseMatrix::operator*(const SparseMatrix &other) const{
    if (n_cols != other.n_rows){
        throw std::invalid_argument("Matrices must have compatible dimensions.");
    }

    // Start from an all-zero result with other.n_cols empty columns.
    SparseMatrix product(n_rows, CData(other.n_cols));

    // Column j of the product is the Z2 sum of our columns selected by the
    // row indices stored in column j of `other`.
    for (size_t j = 0; j < product.n_cols; ++j){
        Column& target = product.cData[j];
        for (const Int k : other.cData[j]){
            target = add_cols(target, cData[k]);
        }
    }
    return product;
}

void SparseMatrix::transpose(){
        CData cData_T(n_rows);
        for(int j =0; j<n_cols; ++j){
            if(!cData[j].empty()){
                for(int x: cData[j]){
                    cData_T[x].push_back(j);
                }
            }
        }
        cData = cData_T;
        n_rows = n_cols;
        n_cols = cData.size();
    }

void SparseMatrix::print() const{
        for (int i = 0; i < n_rows; ++i){
            for (int j = 0; j < n_cols; ++j){
                if (cData[j].empty())
                    {std::cout << " 0";}
                else if (std::binary_search(cData[j].begin(), cData[j].end(), i))//Assumes row indices
                    {std::cout << " 1";}                                        //are ordered
                else
                    {std::cout << " 0";}
                if (n_cols-1 == j)
                    {std::cout << " \n";}
            }
        }
    }

void SparseMatrix::save(std::string path){
        std::fstream f_out;
        f_out.open(path, std::ios::out);
        if(f_out.is_open()){
            f_out << n_rows << "\n";
            f_out << n_cols << "\n";
            for(int j = 0; j < n_cols; ++j){
                int col_j_sz = cData[j].size();
                f_out << col_j_sz;
                for(int i = 0; i < col_j_sz; ++i){
                    f_out << " " << cData[j][i];
                }
                f_out << "\n";
            }
            f_out.close();
        }
        else{
            throw std::runtime_error("File did not open.");
        }
    }

bool matching_pivots(const SparseMatrix& a, const SparseMatrix& b){
    if(a.n_rows != b.n_rows || a.n_cols != b.n_cols){
        throw std::invalid_argument("Input matrices must have the same size.");
    }

    for (int j = 0; j<a.n_cols; ++j){
        bool a_j_empty = a.cData[j].empty();
        bool b_j_empty = b.cData[j].empty();
        if (a_j_empty != b_j_empty){
            return false;
        }
        else if (!a_j_empty){
            if(a.cData[j].back() != b.cData[j].back()){
                return false;
            }
        }
    }
    return true;
}

} // namespace ph_computation

Aucun commentaire:

Enregistrer un commentaire