1use std::{
4 io::Write,
5 path::Path,
6 process::{Command, Stdio},
7};
8
9use anyhow::{Result, ensure};
10use pretty::PrettyPrintMirOptions;
11use rustc_data_structures::fx::FxHashMap as HashMap;
12use rustc_hir::{CoroutineDesugaring, CoroutineKind, HirId, def_id::DefId};
13use rustc_middle::{
14 mir::{
15 BasicBlock, Body, Local, Location, Place, SourceInfo, TerminatorKind,
16 VarDebugInfoContents, pretty, pretty::write_mir_fn,
17 },
18 ty::{Region, Ty, TyCtxt},
19};
20use smallvec::SmallVec;
21
22use super::control_dependencies::ControlDependencies;
23use crate::{PlaceExt, TyExt};
24
/// Extension methods for a MIR [`Body`].
///
/// The behavior described on each method is that of the implementation for
/// [`Body`] in this file.
pub trait BodyExt<'tcx> {
  /// Returns the [`Location`] of the terminator of every basic block whose
  /// terminator is [`TerminatorKind::Return`].
  fn all_returns(&self) -> impl Iterator<Item = Location> + '_;

  /// Returns every [`Location`] in the body: for each basic block, one
  /// location per statement index plus one at index `statements.len()`.
  fn all_locations(&self) -> impl Iterator<Item = Location> + '_;

  /// Returns every [`Location`] inside `block`, from statement index `0`
  /// through `statements.len()` inclusive.
  fn locations_in_block(&self, block: BasicBlock) -> impl Iterator<Item = Location>;

  /// Builds a map from debug-info variable names to the [`Local`] of the
  /// place they refer to.
  ///
  /// Debug-info entries whose value is not a place are skipped.
  fn debug_info_name_map(&self) -> HashMap<String, Local>;

  /// Pretty-prints the body as MIR text, with extra comments disabled.
  ///
  /// Returns an error if the MIR writer fails or if its output is not valid
  /// UTF-8.
  fn to_string(&self, tcx: TyCtxt<'tcx>) -> Result<String>;

  /// Returns the [`HirId`] for `location`, obtained by looking up the
  /// location's [`SourceInfo`] and passing it to
  /// [`BodyExt::source_info_to_hir_id`].
  fn location_to_hir_id(&self, location: Location) -> HirId;

  /// Returns the `lint_root` recorded in the local data of the source scope
  /// that `info` belongs to.
  // NOTE(review): `unwrap_crate_local` presumably panics when the scope data
  // is not available for the local crate -- confirm against rustc.
  fn source_info_to_hir_id(&self, info: &SourceInfo) -> HirId;

  /// Builds the [`ControlDependencies`] of the body's basic blocks, passing
  /// the blocks of all return locations (see [`BodyExt::all_returns`]).
  fn control_dependencies(&self) -> ControlDependencies<BasicBlock>;

  /// If `def_id` is a coroutine desugared from `async`, returns the type of
  /// local `_2`; otherwise returns `None`.
  // NOTE(review): `_2` is presumably the async context argument -- confirm.
  fn async_context(&self, tcx: TyCtxt<'tcx>, def_id: DefId) -> Option<Ty<'tcx>>;

  /// Returns, for every local declared in the body, the places produced by
  /// `interior_paths` on the place for that local.
  fn all_places(
    &self,
    tcx: TyCtxt<'tcx>,
    def_id: DefId,
  ) -> impl Iterator<Item = Place<'tcx>> + '_;

  /// Returns the regions found by `inner_regions` in the type of each
  /// argument local.
  fn regions_in_args(&self) -> impl Iterator<Item = Region<'tcx>> + '_;

  /// Returns the regions found by `inner_regions` in the body's return type.
  fn regions_in_return(&self) -> impl Iterator<Item = Region<'tcx>> + '_;
}
73
74impl<'tcx> BodyExt<'tcx> for Body<'tcx> {
75 fn all_returns(&self) -> impl Iterator<Item = Location> + '_ {
76 self
77 .basic_blocks
78 .iter_enumerated()
79 .filter_map(|(block, data)| match data.terminator().kind {
80 TerminatorKind::Return => Some(Location {
81 block,
82 statement_index: data.statements.len(),
83 }),
84 _ => None,
85 })
86 }
87
88 fn all_locations(&self) -> impl Iterator<Item = Location> + '_ {
89 self
90 .basic_blocks
91 .iter_enumerated()
92 .flat_map(|(block, data)| {
93 (0 ..= data.statements.len()).map(move |statement_index| Location {
94 block,
95 statement_index,
96 })
97 })
98 }
99
100 fn locations_in_block(&self, block: BasicBlock) -> impl Iterator<Item = Location> {
101 let num_stmts = self.basic_blocks[block].statements.len();
102 (0 ..= num_stmts).map(move |statement_index| Location {
103 block,
104 statement_index,
105 })
106 }
107
108 fn debug_info_name_map(&self) -> HashMap<String, Local> {
109 self
110 .var_debug_info
111 .iter()
112 .filter_map(|info| match info.value {
113 VarDebugInfoContents::Place(place) => Some((info.name.to_string(), place.local)),
114 _ => None,
115 })
116 .collect()
117 }
118
119 fn to_string(&self, tcx: TyCtxt<'tcx>) -> Result<String> {
120 let mut buffer = Vec::new();
121 write_mir_fn(
122 tcx,
123 self,
124 &mut |_, _| Ok(()),
125 &mut buffer,
126 PrettyPrintMirOptions {
127 include_extra_comments: false,
128 },
129 )?;
130 Ok(String::from_utf8(buffer)?)
131 }
132
133 fn location_to_hir_id(&self, location: Location) -> HirId {
134 let source_info = self.source_info(location);
135 self.source_info_to_hir_id(source_info)
136 }
137
138 fn source_info_to_hir_id(&self, info: &SourceInfo) -> HirId {
139 let scope = &self.source_scopes[info.scope];
140 let local_data = scope.local_data.as_ref().unwrap_crate_local();
141 local_data.lint_root
142 }
143
144 fn control_dependencies(&self) -> ControlDependencies<BasicBlock> {
145 ControlDependencies::build_many(
146 &self.basic_blocks,
147 self.all_returns().map(|loc| loc.block),
148 )
149 }
150
151 fn async_context(&self, tcx: TyCtxt<'tcx>, def_id: DefId) -> Option<Ty<'tcx>> {
152 if matches!(
153 tcx.coroutine_kind(def_id),
154 Some(CoroutineKind::Desugared(CoroutineDesugaring::Async, _))
155 ) {
156 Some(self.local_decls[Local::from_usize(2)].ty)
157 } else {
158 None
159 }
160 }
161
162 fn regions_in_args(&self) -> impl Iterator<Item = Region<'tcx>> + '_ {
163 self
164 .args_iter()
165 .flat_map(|arg_local| self.local_decls[arg_local].ty.inner_regions())
166 }
167
168 fn regions_in_return(&self) -> impl Iterator<Item = Region<'tcx>> + '_ {
169 self
170 .return_ty()
171 .inner_regions()
172 .collect::<SmallVec<[Region<'tcx>; 8]>>()
173 .into_iter()
174 }
175
176 fn all_places(
177 &self,
178 tcx: TyCtxt<'tcx>,
179 def_id: DefId,
180 ) -> impl Iterator<Item = Place<'tcx>> + '_ {
181 self.local_decls.indices().flat_map(move |local| {
182 Place::from_local(local, tcx).interior_paths(tcx, self, def_id)
183 })
184 }
185}
186
187pub fn run_dot(path: &Path, buf: &[u8]) -> Result<()> {
188 let mut p = Command::new("dot")
189 .args(["-Tpdf", "-o", &path.display().to_string()])
190 .stdin(Stdio::piped())
191 .spawn()?;
192
193 p.stdin.as_mut().unwrap().write_all(buf)?;
194
195 let status = p.wait()?;
196 ensure!(status.success(), "dot for {} failed", path.display());
197
198 Ok(())
199}
200
#[cfg(test)]
mod test {
  use super::BodyExt;
  use crate::test_utils;

  /// Compiles a function taking two `&'a i32` arguments and returning
  /// `&'a i32`, then checks the region counts reported for its body.
  #[test]
  fn test_body_ext() {
    let source = r"
fn foobar<'a>(x: &'a i32, y: &'a i32) -> &'a i32 {
  if *x > 0 {
    return x;
  }

  y
}";

    test_utils::compile_body(source, |_, _, compiled| {
      let mir = &compiled.body;

      // One region per reference argument.
      let arg_regions = mir.regions_in_args().count();
      assert_eq!(arg_regions, 2);

      // A single region in the returned reference.
      let return_regions = mir.regions_in_return().count();
      assert_eq!(return_regions, 1);
    });
  }
}