#![cfg_attr(is_nightly, feature(portable_simd))]
use bitvec::order::Lsb0;
use bitvec::slice::BitSlice;
use bitvec::store::BitStore;
use bitvec::vec::BitVec;
use cfg_if::cfg_if;
use rand::distributions::Standard;
use rand::prelude::Distribution;
use rand::Rng;
#[cfg(feature = "rayon")]
use rayon::iter::IndexedParallelIterator;
#[cfg(feature = "rayon")]
use rayon::slice::{ParallelSlice, ParallelSliceMut};
use serde::{Deserialize, Serialize};
use std::fmt::{Binary, Debug, Formatter};
use std::ops::{BitAnd, BitXor, Not, Range};
use std::slice::{ChunksExact, ChunksExactMut};
#[cfg(is_nightly)]
mod portable_transpose;
mod simple;
#[cfg(all(target_arch = "x86_64", target_feature = "sse2"))]
mod sse2_transpose;
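/// A dense two-dimensional bit matrix, stored row-major in a `Vec` of
/// storage elements `T`.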
#[derive(Clone, Debug, Default, Hash, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)]
pub struct BitMatrix<T: Storage> {
rows: usize,
cols: usize,
data: Vec<T>,
}
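/// An immutable view over the storage of a [`BitMatrix`].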
#[derive(Clone, Debug, Hash, PartialEq, Eq, PartialOrd, Ord)]
pub struct BitMatrixView<'a, T> {
rows: usize,
cols: usize,
data: &'a [T],
}
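/// A mutable view over the storage of a [`BitMatrix`].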
#[derive(Debug, Hash, PartialEq, Eq, PartialOrd, Ord)]
pub struct BitMatrixViewMut<'a, T> {
rows: usize,
cols: usize,
data: &'a mut [T],
}
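/// Backing word type of a [`BitMatrix`]: a plain-old-data integer that
/// supports the bitwise operations and has a known bit width.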
pub trait Storage:
bytemuck::Pod + BitXor<Output = Self> + BitAnd<Output = Self> + Not<Output = Self> + Send + Sync
{
const BITS: usize;
fn zero() -> Self;
}
impl<T: Storage> BitMatrix<T> {
pub fn new(rows: usize, cols: usize) -> Self {
Self::zeros(rows, cols)
}
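    /// Creates a matrix from raw storage elements. The caller must ensure
    /// that `data` holds at least `rows * cols` bits.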
pub fn from_vec(data: Vec<T>, rows: usize, cols: usize) -> Self {
Self { rows, cols, data }
}
pub fn zeros(rows: usize, cols: usize) -> Self {
check_dim::<T>(rows, cols);
Self {
data: vec![T::zero(); rows * cols / T::BITS],
rows,
cols,
}
}
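    /// Creates a matrix filled with uniformly random bits drawn from `rng`.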
pub fn random<R: Rng>(rng: R, rows: usize, cols: usize) -> Self
where
Standard: Distribution<T>,
{
check_dim::<T>(rows, cols);
let data = rng
.sample_iter(Standard)
.take(rows * cols / T::BITS)
.collect();
Self { data, rows, cols }
}
pub fn view(&self) -> BitMatrixView<'_, T> {
BitMatrixView {
rows: self.rows,
cols: self.cols,
data: self.data.as_slice(),
}
}
pub fn view_mut(&mut self) -> BitMatrixViewMut<'_, T> {
BitMatrixViewMut {
rows: self.rows,
cols: self.cols,
data: self.data.as_mut_slice(),
}
}
pub fn rows(&self) -> usize {
self.rows
}
pub fn cols(&self) -> usize {
self.cols
}
pub fn dim(&self) -> (usize, usize) {
(self.rows, self.cols)
}
pub fn storage_len(&self) -> usize {
self.data.len()
}
pub fn into_vec(self) -> Vec<T> {
self.data
}
pub fn iter_rows(&self) -> Rows<'_, T> {
Rows::new(self.view())
}
pub fn iter_raw_rows(&self) -> RawRows<'_, T> {
RawRows::new(self.view())
}
pub fn iter_raw_rows_mut(&mut self) -> RawRowsMut<'_, T> {
RawRowsMut::new(self.view_mut())
}
#[cfg(feature = "rayon")]
pub fn par_iter_raw_rows(&self) -> impl IndexedParallelIterator<Item = &[T]> {
assert_eq!(
0,
self.cols % T::BITS,
"cols must be divisable by bits for raw iterator"
);
self.data.par_chunks_exact(self.cols / T::BITS)
}
#[cfg(feature = "rayon")]
pub fn par_iter_raw_rows_mut(&mut self) -> impl IndexedParallelIterator<Item = &mut [T]> {
assert_eq!(
0,
self.cols % T::BITS,
"cols must be divisable by bits for raw iterator"
);
self.data.par_chunks_exact_mut(self.cols / T::BITS)
}
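    /// Computes the element-wise AND of two matrices with identical dimensions.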
pub fn scalar_and(&self, rhs: &Self) -> Self {
assert_eq!(
self.dim(),
rhs.dim(),
"Dimensions must be identical for scalar_and"
);
let data = self
.data
.iter()
.zip(&rhs.data)
.map(|(a, b)| *a & *b)
.collect();
Self {
rows: self.rows,
cols: self.cols,
data,
}
}
}
impl<T> BitMatrix<T>
where
T: Storage + BitStore<Unalias = T>,
{
pub fn from_bits(bits: &BitSlice<T, Lsb0>, rows: usize, cols: usize) -> Self {
assert_eq!(
bits.len() % T::BITS,
0,
"Length of bits must be multiple of T::BITS"
);
assert_eq!(bits.len(), rows * cols, "bits.len() != rows * cols");
let data = bits.to_bitvec().into_vec();
Self { rows, cols, data }
}
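    /// Creates the `size x size` identity matrix.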
pub fn identity(size: usize) -> Self {
check_dim::<T>(size, size);
let mut bv: BitVec<T> = BitVec::repeat(false, size * size);
let mut idx = 0;
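        // Setting every (size + 1)-th bit marks exactly the diagonal
        // entries in row-major order.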
while idx < size * size {
bv.set(idx, true);
idx += size + 1;
}
Self::from_vec(bv.into_vec(), size, size)
}
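    /// Computes the matrix product `self * rhs` over GF(2): multiplication
    /// is AND, addition is XOR.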
pub fn mat_mul(self, rhs: &Self) -> Self {
assert_eq!(
self.cols, rhs.rows,
"Illegal dimensions for matrix multiplication"
);
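        // Dot product over GF(2): AND the two rows, then XOR-reduce,
        // i.e. take the parity of the conjunction.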
let dotp = |l_row: &BitSlice<T>, r_row| -> bool {
let and = l_row.to_bitvec() & r_row;
and.iter().by_vals().reduce(BitXor::bitxor).unwrap()
};
let rhs = rhs.view().transpose();
let bits = self
.iter_rows()
.flat_map(|l_row| rhs.iter_rows().map(|r_row| dotp(l_row, r_row)));
let mut bv: BitVec<T> = BitVec::with_capacity(self.rows * rhs.rows);
bv.extend(bits);
Self::from_vec(bv.into_vec(), self.rows, rhs.rows)
}
pub fn into_bitvec(self) -> BitVec<T> {
BitVec::from_vec(self.data)
}
}
impl<'a, T: Storage> BitMatrixView<'a, T> {
pub fn from_slice(data: &'a [T], rows: usize, cols: usize) -> Self {
assert_eq!(
data.len() * T::BITS,
rows * cols,
"data.len() does not match rows * cols"
);
Self { rows, cols, data }
}
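    /// Transposes via the platform-specific kernel without checking the
    /// matrix dimensions. Unlike [`Self::transpose`], there is no fallback
    /// to the scalar implementation.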
pub fn fast_transpose(&self) -> BitMatrix<T> {
cfg_if! {
if #[cfg(all(target_arch = "x86_64", target_feature = "sse2"))] {
let transposed = sse2_transpose::transpose(self.data, self.rows, self.cols);
} else if #[cfg(is_nightly)] {
let transposed = portable_transpose::transpose(self.data, self.rows, self.cols);
} else {
use std::compile_error;
compile_error!("Target must either be x86_64 with sse2 enabled or crate \
feature \"portable_transpose\" must be enabled (requires nightly)")
}
}
BitMatrix::from_vec(transposed, self.cols, self.rows)
}
#[inline]
pub fn raw_row(&self, row: usize) -> Option<&'a [T]> {
let idx = raw_row_idx::<T>(row, self.cols);
self.data.get(idx)
}
fn can_do_sse_trans(&self) -> bool {
self.rows % 8 == 0 && self.cols % 8 == 0 && self.rows >= 16 && self.cols >= 16
}
}
impl<'a, T: BitStore<Unalias = T>> BitMatrixView<'a, T> {
pub fn as_bitslice(&self) -> &'a BitSlice<T> {
BitSlice::from_slice(self.data)
}
#[inline]
pub fn row(&self, row: usize) -> Option<&'a BitSlice<T>> {
let data = self.as_bitslice();
let start_idx = row * self.cols;
let end_idx = (row + 1) * self.cols;
data.get(start_idx..end_idx)
}
}
impl<'a, T: Storage + BitStore<Unalias = T>> BitMatrixView<'a, T> {
pub fn transpose(&self) -> BitMatrix<T> {
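        // Use the SIMD kernel when the dimensions permit it and a fast
        // backend is compiled in; otherwise fall back to the scalar
        // transpose in `simple`.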
let transposed = if self.can_do_sse_trans()
&& (cfg!(is_nightly) || cfg!(all(target_arch = "x86_64", target_feature = "sse2")))
{
cfg_if! {
if #[cfg(all(target_arch = "x86_64", target_feature = "sse2"))] {
sse2_transpose::transpose(self.data, self.rows, self.cols)
} else if #[cfg(is_nightly)] {
portable_transpose::transpose(self.data, self.rows, self.cols)
} else {
simple::transpose(self.data, self.rows, self.cols)
}
}
} else {
simple::transpose(self.data, self.rows, self.cols)
};
BitMatrix::from_vec(transposed, self.cols, self.rows)
}
}
impl<'a, T: Storage> BitMatrixViewMut<'a, T> {
pub fn raw_row_mut(&'a mut self, row: usize) -> Option<&'a mut [T]> {
let idx = raw_row_idx::<T>(row, self.cols);
self.data.get_mut(idx)
}
}
#[inline]
fn raw_row_idx<T: Storage>(row: usize, cols: usize) -> Range<usize> {
assert_eq!(
0,
cols % T::BITS,
"cols must be divisable by T::BITS for raw_row. Use row() instead."
);
let cols_el = cols / T::BITS;
let start_idx = row * cols_el;
let end_idx = (row + 1) * cols_el;
start_idx..end_idx
}
impl<T: Storage> Not for BitMatrix<T> {
type Output = Self;
fn not(mut self) -> Self::Output {
self.data.iter_mut().for_each(|el| *el = !*el);
self
}
}
impl<T: Storage> BitXor for BitMatrix<T> {
type Output = BitMatrix<T>;
fn bitxor(mut self, rhs: Self) -> Self::Output {
assert_eq!(
self.dim(),
rhs.dim(),
"BitXor on matrices with different dimensions"
);
self.data.iter_mut().zip(rhs.data).for_each(|(a, b)| {
*a = *a ^ b;
});
self
}
}
impl<T: Storage> BitXor<&Self> for BitMatrix<T> {
type Output = BitMatrix<T>;
fn bitxor(mut self, rhs: &Self) -> Self::Output {
assert_eq!(
self.dim(),
rhs.dim(),
"BitXor on matrices with different dimensions"
);
self.data.iter_mut().zip(&rhs.data).for_each(|(a, b)| {
*a = *a ^ *b;
});
self
}
}
impl<T: Binary + Storage> Binary for BitMatrix<T> {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
let fmt_bin: Vec<_> = self.data.iter().map(|el| format!("{el:b}")).collect();
f.debug_struct("BitMatrix")
.field("rows", &self.rows)
.field("cols", &self.cols)
.field("data", &fmt_bin)
.finish()
}
}
macro_rules! impl_storage {
    ($($t:ty),+ $(,)?) => {$(
        impl Storage for $t {
            const BITS: usize = <$t>::BITS as usize;
            fn zero() -> Self {
                0
            }
        }
    )+};
}
impl_storage!(u8, u16, u32, u64, u128);
#[derive(Clone, Debug)]
pub struct Rows<'a, T> {
view: BitMatrixView<'a, T>,
row: usize,
}
impl<'a, T> Rows<'a, T> {
pub fn new(view: BitMatrixView<'a, T>) -> Self {
Self { view, row: 0 }
}
}
impl<'a, T: BitStore<Unalias = T>> Iterator for Rows<'a, T> {
type Item = &'a BitSlice<T>;
#[inline]
fn next(&mut self) -> Option<Self::Item> {
let ret = self.view.row(self.row);
self.row += 1;
ret
}
}
#[derive(Clone, Debug)]
pub struct RawRows<'a, T> {
chunks: ChunksExact<'a, T>,
}
impl<'a, T: Storage> RawRows<'a, T> {
pub fn new(view: BitMatrixView<'a, T>) -> Self {
assert_eq!(
0,
view.cols % T::BITS,
"cols of BitMatrix must be multiple of T::BITS for raw rows iterator"
);
let chunks = view.data.chunks_exact(view.cols / T::BITS);
Self { chunks }
}
}
impl<'a, T: Storage> Iterator for RawRows<'a, T> {
type Item = &'a [T];
#[inline]
fn next(&mut self) -> Option<Self::Item> {
self.chunks.next()
}
}
#[derive(Debug)]
pub struct RawRowsMut<'a, T> {
chunks: ChunksExactMut<'a, T>,
}
impl<'a, T: Storage> RawRowsMut<'a, T> {
pub fn new(view: BitMatrixViewMut<'a, T>) -> Self {
assert_eq!(
0,
view.cols % T::BITS,
"cols of BitMatrix must be multiple of T::BITS for raw rows iterator"
);
let chunks = view.data.chunks_exact_mut(view.cols / T::BITS);
Self { chunks }
}
}
impl<'a, T: Storage> Iterator for RawRowsMut<'a, T> {
type Item = &'a mut [T];
#[inline]
fn next(&mut self) -> Option<Self::Item> {
self.chunks.next()
}
}
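// Every constructor requires `rows * cols` to be a multiple of `T::BITS`, so
// that the bit data exactly fills whole storage elements.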
fn check_dim<T: Storage>(rows: usize, cols: usize) {
assert_eq!(
(rows * cols) % T::BITS,
0,
"rows * cols must be divisable by T::BITS"
);
}
#[cfg(test)]
mod tests {
use crate::BitMatrix;
use bitvec::vec::BitVec;
use ndarray::Array2;
use num_traits::{One, Zero};
use rand::{thread_rng, Rng};
use std::fmt::{Debug, Formatter};
use std::ops::{Add, Div, Mul, Sub};
#[derive(Copy, Clone)]
struct Z2(u8);
#[test]
fn mul() {
let id: BitMatrix<u8> = BitMatrix::identity(128);
let other = BitMatrix::random(thread_rng(), 128, 256);
let mul = id.mat_mul(&other);
assert_eq!(other, mul)
}
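    // Supplementary property check (uses only the API defined above):
    // transposing twice must return the original matrix.
    #[test]
    fn double_transpose_roundtrip() {
        let mat: BitMatrix<u8> = BitMatrix::random(thread_rng(), 16, 24);
        assert_eq!(mat, mat.view().transpose().view().transpose());
    }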
#[test]
fn mul_ndarray() {
let cols = 128;
let rows = 128;
let mut rng = thread_rng();
let nd_arr1: Vec<_> = (0..rows * cols).map(|_| Z2(rng.gen_range(0..2))).collect();
let bitmat1 = BitMatrix::from_bits(
&nd_arr1.iter().map(|bit| bit.0 == 1).collect::<BitVec<u8>>(),
rows,
cols,
);
let nd_arr1 = Array2::from_shape_vec((rows, cols), nd_arr1).unwrap();
let nd_arr2: Vec<_> = (0..rows * cols).map(|_| Z2(rng.gen_range(0..2))).collect();
let bitmat2 = BitMatrix::from_bits(
&nd_arr2.iter().map(|bit| bit.0 == 1).collect::<BitVec<u8>>(),
rows,
cols,
);
let nd_arr2 = Array2::from_shape_vec((rows, cols), nd_arr2).unwrap();
let res_nd_arr = nd_arr1.dot(&nd_arr2);
let res_bit_mat = bitmat1.mat_mul(&bitmat2);
for (nd_row, bit_mat_row) in res_nd_arr.rows().into_iter().zip(res_bit_mat.iter_rows()) {
for (nd_bit, bit_mat_bit) in nd_row.iter().zip(bit_mat_row) {
assert_eq!(
nd_bit.0 == 1,
*bit_mat_bit,
"BitMatrix::mat_mult differs from nd_array"
);
}
}
}
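    // Supplementary check: XOR with itself yields the zero matrix, and
    // double negation is the identity.
    #[test]
    fn xor_and_not() {
        let mat: BitMatrix<u8> = BitMatrix::random(thread_rng(), 8, 16);
        assert_eq!(BitMatrix::zeros(8, 16), mat.clone() ^ &mat);
        assert_eq!(mat, !!mat.clone());
    }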
    impl Add for Z2 {
        type Output = Z2;
        fn add(self, rhs: Self) -> Self::Output {
            Self((self.0 + rhs.0) % 2)
        }
    }
    impl Sub for Z2 {
        type Output = Z2;
        fn sub(self, rhs: Self) -> Self::Output {
            // In Z2, subtraction coincides with addition.
            Self((self.0 + rhs.0) % 2)
        }
    }
impl Mul for Z2 {
type Output = Z2;
fn mul(self, rhs: Self) -> Self::Output {
Self(self.0 * rhs.0 % 2)
}
}
impl Div for Z2 {
type Output = Z2;
fn div(self, rhs: Self) -> Self::Output {
Self(self.0 / rhs.0 % 2)
}
}
impl Zero for Z2 {
fn zero() -> Self {
Self(0)
}
fn is_zero(&self) -> bool {
self.0 == 0
}
}
impl One for Z2 {
fn one() -> Self {
Self(1)
}
}
impl Debug for Z2 {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
write!(f, "{:?}", self.0)
}
}
impl PartialEq<Array2<Z2>> for BitMatrix<u8> {
fn eq(&self, other: &Array2<Z2>) -> bool {
other
.rows()
.into_iter()
.zip(self.iter_rows())
.all(|(r1, r2)| {
r1.iter()
.zip(r2)
.all(|(el1, el2)| matches!((el1, *el2), (Z2(0), false) | (Z2(1), true)))
})
}
}
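    // Supplementary check: for a `u8`-backed matrix, each raw row is a chunk
    // of `cols / 8` storage elements.
    #[test]
    fn raw_rows_chunking() {
        let mat: BitMatrix<u8> = BitMatrix::zeros(4, 16);
        assert_eq!(4, mat.iter_raw_rows().count());
        assert!(mat.iter_raw_rows().all(|row| row.len() == 2));
    }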
}