rustc_utils/source_map/spanner/
span_tree.rs

1use intervaltree::IntervalTree;
2use rustc_span::{BytePos, SpanData, source_map::Spanned};
3
4/// Interval tree data structure specialized to spans.
5pub struct SpanTree<T> {
6  tree: IntervalTree<BytePos, (SpanData, T)>,
7  len: usize,
8}
9
10impl<T> SpanTree<T> {
11  pub fn new(spans: impl IntoIterator<Item = Spanned<T>>) -> Self {
12    let tree = spans
13      .into_iter()
14      .map(|spanned| {
15        let data = spanned.span.data();
16        (data.lo .. data.hi, (data, spanned.node))
17      })
18      .collect::<IntervalTree<_, _>>();
19    let len = tree.iter().count();
20    SpanTree { tree, len }
21  }
22
23  pub fn len(&self) -> usize {
24    self.len
25  }
26
27  pub fn iter(&self) -> impl Iterator<Item = &'_ T> + '_ {
28    self.tree.iter().map(|el| &el.value.1)
29  }
30
31  /// Find all spans that overlap with `query`
32  pub fn overlapping(
33    &self,
34    query: SpanData,
35  ) -> impl Iterator<Item = &'_ (SpanData, T)> + '_ {
36    self.tree.query(query.lo .. query.hi).map(|el| &el.value)
37  }
38}
39
#[cfg(test)]
mod test {
  use rustc_span::SyntaxContext;

  use super::*;

  /// Checks that `overlapping` returns exactly the entries whose byte ranges
  /// intersect the queried range.
  #[test]
  fn span_tree_test() {
    rustc_span::create_default_session_globals_then(|| {
      // Root-context span data covering the bytes `lo .. hi`.
      fn span_data(lo: u32, hi: u32) -> SpanData {
        SpanData {
          lo: BytePos(lo),
          hi: BytePos(hi),
          ctxt: SyntaxContext::root(),
          parent: None,
        }
      }

      // A labelled entry located at `lo .. hi`.
      fn entry(node: &'static str, lo: u32, hi: u32) -> Spanned<&'static str> {
        Spanned {
          span: span_data(lo, hi).span(),
          node,
        }
      }

      let tree = SpanTree::new([entry("a", 0, 1), entry("b", 2, 3), entry("c", 0, 5)]);

      // Labels of all entries overlapping `lo .. hi`, sorted so the
      // comparison does not depend on the tree's iteration order.
      let labels_overlapping = |lo, hi| {
        let mut labels: Vec<&str> = tree
          .overlapping(span_data(lo, hi))
          .map(|(_, label)| *label)
          .collect();
        labels.sort_unstable();
        labels
      };

      let cases: [(u32, u32, &[&str]); 4] = [
        (0, 2, &["a", "c"]),
        (0, 3, &["a", "b", "c"]),
        (2, 3, &["b", "c"]),
        (6, 8, &[]),
      ];
      for (lo, hi, expected) in cases {
        assert_eq!(labels_overlapping(lo, hi), expected);
      }
    });
  }
}