1use rustc_hash::FxHashMap;
2use std::{fmt, hash::Hash};
3
4use crate::{
5 IndexSet, IndexedDomain, IndexedValue, ToIndex, bitset::BitSet, pointer::PointerFamily,
6};
7
/// A matrix stored row-by-row: each row key `R` maps to an [`IndexSet`] of
/// column values `C`, all indexed against one shared column domain.
///
/// Rows are materialized lazily. A row key that has never been inserted (or
/// was removed with `clear_row`) reads as an empty row through `row` and
/// `row_set`.
pub struct IndexMatrix<'a, R, C: IndexedValue + 'a, S: BitSet, P: PointerFamily<'a>> {
    /// Materialized rows, keyed by row value.
    pub(crate) matrix: FxHashMap<R, IndexSet<'a, C, S, P>>,
    /// An empty set over `col_domain`. It is cloned as the initial value of
    /// newly created rows and returned by `row_set` for absent rows.
    empty_set: IndexSet<'a, C, S, P>,
    /// The domain that column values are converted to indices against.
    col_domain: P::Pointer<IndexedDomain<C>>,
}
17
18impl<'a, R, C, S, P> IndexMatrix<'a, R, C, S, P>
19where
20 R: PartialEq + Eq + Hash + Clone,
21 C: IndexedValue + 'a,
22 S: BitSet,
23 P: PointerFamily<'a>,
24{
25 pub fn new(col_domain: &P::Pointer<IndexedDomain<C>>) -> Self {
27 IndexMatrix {
28 matrix: FxHashMap::default(),
29 empty_set: IndexSet::new(col_domain),
30 col_domain: col_domain.clone(),
31 }
32 }
33
34 pub(crate) fn ensure_row(&mut self, row: R) -> &mut IndexSet<'a, C, S, P> {
35 self.matrix
36 .entry(row)
37 .or_insert_with(|| self.empty_set.clone())
38 }
39
40 pub fn insert<M>(&mut self, row: R, col: impl ToIndex<C, M>) -> bool {
42 let col = col.to_index(&self.col_domain);
43 self.ensure_row(row).insert(col)
44 }
45
46 pub fn union_into_row(&mut self, into: R, from: &IndexSet<'a, C, S, P>) -> bool {
48 self.ensure_row(into).union_changed(from)
49 }
50
51 pub fn union_rows(&mut self, from: R, to: R) -> bool {
53 if from == to {
54 return false;
55 }
56
57 self.ensure_row(from.clone());
58 self.ensure_row(to.clone());
59
60 let [Some(from), Some(to)] =
62 (unsafe { self.matrix.get_disjoint_unchecked_mut([&from, &to]) })
63 else {
64 unreachable!()
65 };
66 to.union_changed(from)
67 }
68
69 pub fn row(&self, row: &R) -> impl Iterator<Item = &C> {
71 self.matrix.get(row).into_iter().flat_map(|set| set.iter())
72 }
73
74 pub fn rows(&self) -> impl ExactSizeIterator<Item = (&R, &IndexSet<'a, C, S, P>)> {
76 self.matrix.iter()
77 }
78
79 pub fn row_set(&self, row: &R) -> &IndexSet<'a, C, S, P> {
81 self.matrix.get(row).unwrap_or(&self.empty_set)
82 }
83
84 pub fn clear_row(&mut self, row: &R) {
86 self.matrix.remove(row);
87 }
88
89 pub fn col_domain(&self) -> &P::Pointer<IndexedDomain<C>> {
91 &self.col_domain
92 }
93}
94
95impl<'a, R, C, S, P> IndexMatrix<'a, R, C, S, P>
96where
97 R: PartialEq + Eq + Hash + Clone + 'a,
98 C: IndexedValue + 'a,
99 S: BitSet,
100 P: PointerFamily<'a>,
101{
102 pub fn transpose<T, M>(
107 &self,
108 row_domain: &P::Pointer<IndexedDomain<T>>,
109 ) -> IndexMatrix<'a, C::Index, T, S, P>
110 where
111 T: IndexedValue + 'a,
112 R: ToIndex<T, M>,
113 {
114 let mut mtx = IndexMatrix::new(row_domain);
115 for (row, cols) in self.rows() {
116 for col in cols.indices() {
117 mtx.insert(col, row.clone());
118 }
119 }
120 mtx
121 }
122}
123
124impl<'a, R, C, S, P> PartialEq for IndexMatrix<'a, R, C, S, P>
125where
126 R: PartialEq + Eq + Hash + Clone,
127 C: IndexedValue + 'a,
128 S: BitSet,
129 P: PointerFamily<'a>,
130{
131 fn eq(&self, other: &Self) -> bool {
132 self.matrix == other.matrix
133 }
134}
135
/// Marker impl: the `PartialEq` impl compares only the row maps, and is
/// declared reflexive, symmetric and transitive here.
impl<'a, R, C, S, P> Eq for IndexMatrix<'a, R, C, S, P>
where
    R: PartialEq + Eq + Hash + Clone,
    C: IndexedValue + 'a,
    S: BitSet,
    P: PointerFamily<'a>,
{
}
144
145impl<'a, R, C, S, P> Clone for IndexMatrix<'a, R, C, S, P>
146where
147 R: PartialEq + Eq + Hash + Clone,
148 C: IndexedValue + 'a,
149 S: BitSet,
150 P: PointerFamily<'a>,
151{
152 fn clone(&self) -> Self {
153 Self {
154 matrix: self.matrix.clone(),
155 empty_set: self.empty_set.clone(),
156 col_domain: self.col_domain.clone(),
157 }
158 }
159
160 fn clone_from(&mut self, source: &Self) {
161 for col in self.matrix.values_mut() {
162 col.clear();
163 }
164
165 for (row, col) in source.matrix.iter() {
166 self.ensure_row(row.clone()).clone_from(col);
167 }
168
169 self.empty_set = source.empty_set.clone();
170 self.col_domain = source.col_domain.clone();
171 }
172}
173
174impl<'a, R, C, S, P> fmt::Debug for IndexMatrix<'a, R, C, S, P>
175where
176 R: PartialEq + Eq + Hash + Clone + fmt::Debug,
177 C: IndexedValue + fmt::Debug + 'a,
178 S: BitSet,
179 P: PointerFamily<'a>,
180{
181 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
182 f.debug_map().entries(self.rows()).finish()
183 }
184}
185
#[cfg(test)]
mod test {
    use crate::{IndexedDomain, test_utils::TestIndexMatrix};
    use std::rc::Rc;

    fn mk(s: &str) -> String {
        s.to_string()
    }

    #[test]
    fn test_indexmatrix() {
        let col_domain = Rc::new(IndexedDomain::from_iter([mk("a"), mk("b"), mk("c")]));
        let mut mtx = TestIndexMatrix::new(&col_domain);
        mtx.insert(0, mk("b"));
        mtx.insert(1, mk("c"));
        assert_eq!(mtx.row(&0).collect::<Vec<_>>(), vec!["b"]);
        assert_eq!(mtx.row(&1).collect::<Vec<_>>(), vec!["c"]);

        assert!(mtx.union_rows(0, 1));
        assert_eq!(mtx.row(&1).collect::<Vec<_>>(), vec!["b", "c"]);
    }

    /// `union_rows` reports `false` for self-unions and for unions that add
    /// nothing new.
    #[test]
    fn test_union_rows_no_change() {
        let col_domain = Rc::new(IndexedDomain::from_iter([mk("a"), mk("b"), mk("c")]));
        let mut mtx = TestIndexMatrix::new(&col_domain);
        mtx.insert(0, mk("a"));

        assert!(!mtx.union_rows(0, 0));
        assert!(mtx.union_rows(0, 1));
        // Row 1 already contains everything in row 0.
        assert!(!mtx.union_rows(0, 1));
        assert_eq!(mtx.row(&1).collect::<Vec<_>>(), vec!["a"]);
    }

    /// Absent and cleared rows both read as empty.
    #[test]
    fn test_missing_and_cleared_rows() {
        let col_domain = Rc::new(IndexedDomain::from_iter([mk("a"), mk("b"), mk("c")]));
        let mut mtx = TestIndexMatrix::new(&col_domain);

        assert_eq!(mtx.row(&7).count(), 0);

        mtx.insert(0, mk("c"));
        assert_eq!(mtx.row(&0).collect::<Vec<_>>(), vec!["c"]);
        mtx.clear_row(&0);
        assert_eq!(mtx.row(&0).count(), 0);
    }
}