import numpy as np
from scipy.sparse import csr_matrix, issparse


def delete_from_csr(mat, row_indices=None, col_indices=None):
    """
    Remove the rows (denoted by ``row_indices``) and columns (denoted by
    ``col_indices``) from the CSR sparse matrix ``mat``.

    Parameters
    ----------
    mat : CSR sparse matrix or array
        Input matrix; must be in CSR format (``mat.format == "csr"``).
    row_indices : iterable of int, optional
        Indices of the rows to delete. ``None`` or an empty iterable deletes
        no rows. Lists, tuples, generators and NumPy integer arrays are all
        accepted.
    col_indices : iterable of int, optional
        Indices of the columns to delete, with the same conventions as
        ``row_indices``.

    Returns
    -------
    CSR sparse matrix
        A new matrix without the deleted rows/columns. If there is nothing to
        delete, ``mat`` itself is returned (not a copy).

    Raises
    ------
    ValueError
        If ``mat`` is not a CSR sparse matrix/array.
    IndexError
        If an index is out of range for its axis.

    WARNING: Indices of altered axes are reset in the returned matrix
    (the surviving rows/columns are renumbered consecutively from 0).
    """
    if not (issparse(mat) and mat.format == "csr"):
        raise ValueError("works only for CSR format -- use .tocsr() first")

    # Materialize the indices once (this also accepts generators) and decide
    # by length, not truthiness. A multi-element NumPy array has no truth
    # value, and np.array([0]) is falsy even though it names an index.
    rows = [] if row_indices is None else list(row_indices)
    cols = [] if col_indices is None else list(col_indices)

    if rows:
        # True = keep, False = delete; boolean indexing keeps the True rows.
        row_mask = np.ones(mat.shape[0], dtype=bool)
        row_mask[rows] = False
        mat = mat[row_mask]
    if cols:
        # Deleting rows leaves the column count unchanged, so shape[1] is
        # still the original width here.
        col_mask = np.ones(mat.shape[1], dtype=bool)
        col_mask[cols] = False
        mat = mat[:, col_mask]
    return mat
Here is what the above code does:
1. It checks that the matrix is in CSR format. If not, it raises a `ValueError`.
2. It copies the row indices and column indices to delete into two lists (each list is empty if no indices were given).
3. For each non-empty list, it builds a boolean mask over that axis that is `False` at the indices to delete and `True` everywhere else.
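4. It indexes the matrix with these masks and returns the result, with the selected rows and columns removed. The remaining rows and columns are renumbered from 0. If there is nothing to delete, the original matrix is returned unchanged.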