indexical/
matrix.rs

1use rustc_hash::FxHashMap;
2use std::{fmt, hash::Hash};
3
4use crate::{
5    IndexSet, IndexedDomain, IndexedValue, ToIndex, bitset::BitSet, pointer::PointerFamily,
6};
7
8/// An unordered collections of pairs `(R, C)`, implemented with a sparse bit-matrix.
9///
10/// "Sparse" means "hash map from rows to bit-sets of columns". Subsequently, only column types `C` must be indexed,
11/// while row types `R` only need be hashable.
12pub struct IndexMatrix<'a, R, C: IndexedValue + 'a, S: BitSet, P: PointerFamily<'a>> {
13    pub(crate) matrix: FxHashMap<R, IndexSet<'a, C, S, P>>,
14    empty_set: IndexSet<'a, C, S, P>,
15    col_domain: P::Pointer<IndexedDomain<C>>,
16}
17
18impl<'a, R, C, S, P> IndexMatrix<'a, R, C, S, P>
19where
20    R: PartialEq + Eq + Hash + Clone,
21    C: IndexedValue + 'a,
22    S: BitSet,
23    P: PointerFamily<'a>,
24{
25    /// Creates an empty matrix.
26    pub fn new(col_domain: &P::Pointer<IndexedDomain<C>>) -> Self {
27        IndexMatrix {
28            matrix: FxHashMap::default(),
29            empty_set: IndexSet::new(col_domain),
30            col_domain: col_domain.clone(),
31        }
32    }
33
34    pub(crate) fn ensure_row(&mut self, row: R) -> &mut IndexSet<'a, C, S, P> {
35        self.matrix
36            .entry(row)
37            .or_insert_with(|| self.empty_set.clone())
38    }
39
40    /// Inserts a pair `(row, col)` into the matrix, returning true if `self` changed.
41    pub fn insert<M>(&mut self, row: R, col: impl ToIndex<C, M>) -> bool {
42        let col = col.to_index(&self.col_domain);
43        self.ensure_row(row).insert(col)
44    }
45
46    /// Adds all elements of `from` into the row `into`.
47    pub fn union_into_row(&mut self, into: R, from: &IndexSet<'a, C, S, P>) -> bool {
48        self.ensure_row(into).union_changed(from)
49    }
50
51    /// Adds all elements from the row `from` into the row `into`.
52    pub fn union_rows(&mut self, from: R, to: R) -> bool {
53        if from == to {
54            return false;
55        }
56
57        self.ensure_row(from.clone());
58        self.ensure_row(to.clone());
59
60        // SAFETY: `from` != `to` therefore this is a disjoint mutable borrow
61        let [Some(from), Some(to)] =
62            (unsafe { self.matrix.get_disjoint_unchecked_mut([&from, &to]) })
63        else {
64            unreachable!()
65        };
66        to.union_changed(from)
67    }
68
69    /// Returns an iterator over the elements in `row`.
70    pub fn row(&self, row: &R) -> impl Iterator<Item = &C> {
71        self.matrix.get(row).into_iter().flat_map(|set| set.iter())
72    }
73
74    /// Returns an iterator over all rows in the matrix.
75    pub fn rows(&self) -> impl ExactSizeIterator<Item = (&R, &IndexSet<'a, C, S, P>)> {
76        self.matrix.iter()
77    }
78
79    /// Returns the [`IndexSet`] for a particular `row`.
80    pub fn row_set(&self, row: &R) -> &IndexSet<'a, C, S, P> {
81        self.matrix.get(row).unwrap_or(&self.empty_set)
82    }
83
84    /// Clears all the elements from the `row`.
85    pub fn clear_row(&mut self, row: &R) {
86        self.matrix.remove(row);
87    }
88
89    /// Returns the [`IndexedDomain`] for the column type.
90    pub fn col_domain(&self) -> &P::Pointer<IndexedDomain<C>> {
91        &self.col_domain
92    }
93}
94
95impl<'a, R, C, S, P> IndexMatrix<'a, R, C, S, P>
96where
97    R: PartialEq + Eq + Hash + Clone + 'a,
98    C: IndexedValue + 'a,
99    S: BitSet,
100    P: PointerFamily<'a>,
101{
102    // R = Local
103    // C = Allocation
104
105    /// Transposes the matrix, assuming the row type is also indexed.
106    pub fn transpose<T, M>(
107        &self,
108        row_domain: &P::Pointer<IndexedDomain<T>>,
109    ) -> IndexMatrix<'a, C::Index, T, S, P>
110    where
111        T: IndexedValue + 'a,
112        R: ToIndex<T, M>,
113    {
114        let mut mtx = IndexMatrix::new(row_domain);
115        for (row, cols) in self.rows() {
116            for col in cols.indices() {
117                mtx.insert(col, row.clone());
118            }
119        }
120        mtx
121    }
122}
123
124impl<'a, R, C, S, P> PartialEq for IndexMatrix<'a, R, C, S, P>
125where
126    R: PartialEq + Eq + Hash + Clone,
127    C: IndexedValue + 'a,
128    S: BitSet,
129    P: PointerFamily<'a>,
130{
131    fn eq(&self, other: &Self) -> bool {
132        self.matrix == other.matrix
133    }
134}
135
136impl<'a, R, C, S, P> Eq for IndexMatrix<'a, R, C, S, P>
137where
138    R: PartialEq + Eq + Hash + Clone,
139    C: IndexedValue + 'a,
140    S: BitSet,
141    P: PointerFamily<'a>,
142{
143}
144
145impl<'a, R, C, S, P> Clone for IndexMatrix<'a, R, C, S, P>
146where
147    R: PartialEq + Eq + Hash + Clone,
148    C: IndexedValue + 'a,
149    S: BitSet,
150    P: PointerFamily<'a>,
151{
152    fn clone(&self) -> Self {
153        Self {
154            matrix: self.matrix.clone(),
155            empty_set: self.empty_set.clone(),
156            col_domain: self.col_domain.clone(),
157        }
158    }
159
160    fn clone_from(&mut self, source: &Self) {
161        for col in self.matrix.values_mut() {
162            col.clear();
163        }
164
165        for (row, col) in source.matrix.iter() {
166            self.ensure_row(row.clone()).clone_from(col);
167        }
168
169        self.empty_set = source.empty_set.clone();
170        self.col_domain = source.col_domain.clone();
171    }
172}
173
174impl<'a, R, C, S, P> fmt::Debug for IndexMatrix<'a, R, C, S, P>
175where
176    R: PartialEq + Eq + Hash + Clone + fmt::Debug,
177    C: IndexedValue + fmt::Debug + 'a,
178    S: BitSet,
179    P: PointerFamily<'a>,
180{
181    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
182        f.debug_map().entries(self.rows()).finish()
183    }
184}
185
#[cfg(test)]
mod test {
    use crate::{IndexedDomain, test_utils::TestIndexMatrix};
    use std::rc::Rc;

    /// Shorthand for turning a string literal into an owned `String` column value.
    fn mk(s: &str) -> String {
        String::from(s)
    }

    /// Inserting into two rows and unioning one into the other yields the merged row.
    #[test]
    fn test_indexmatrix() {
        let col_domain = Rc::new(IndexedDomain::from_iter(["a", "b", "c"].map(mk)));
        let mut mtx = TestIndexMatrix::new(&col_domain);

        // Populate two distinct rows with one column each.
        mtx.insert(0, mk("b"));
        mtx.insert(1, mk("c"));

        let row0: Vec<_> = mtx.row(&0).collect();
        assert_eq!(row0, vec!["b"]);
        let row1: Vec<_> = mtx.row(&1).collect();
        assert_eq!(row1, vec!["c"]);

        // Merging row 0 into row 1 must report a change and extend row 1.
        let changed = mtx.union_rows(0, 1);
        assert!(changed);
        let merged: Vec<_> = mtx.row(&1).collect();
        assert_eq!(merged, vec!["b", "c"]);
    }
}