base64/engine/general_purpose/
decode.rs1use crate::{
2 engine::{general_purpose::INVALID_VALUE, DecodeEstimate, DecodeMetadata, DecodePaddingMode},
3 DecodeError, PAD_BYTE,
4};
5
// The chunk decoders consume 8 base64 symbols at a time and only ever see padding-free input;
// padding is left for the suffix decoder.
const INPUT_CHUNK_LEN: usize = 8;
// 8 symbols x 6 bits = 48 bits = 6 meaningful output bytes per chunk.
const DECODED_CHUNK_LEN: usize = 6;

// Each chunk is stored with a single 8-byte `u64` write, so 2 bytes past the 6 meaningful ones
// are also overwritten. Output windows must include room for them, and later writes must
// replace them with real data.
const DECODED_CHUNK_SUFFIX: usize = 2;

// Number of chunks handled per iteration of the manually unrolled fast loop.
const CHUNKS_PER_FAST_LOOP_BLOCK: usize = 4;

// Input symbols consumed by one fast-loop iteration.
const INPUT_BLOCK_LEN: usize = INPUT_CHUNK_LEN * CHUNKS_PER_FAST_LOOP_BLOCK;

// Output bytes touched by one fast-loop iteration: the meaningful bytes of every chunk plus the
// 2-byte spill of the final chunk's `u64` write.
const DECODED_BLOCK_LEN: usize =
    DECODED_CHUNK_SUFFIX + DECODED_CHUNK_LEN * CHUNKS_PER_FAST_LOOP_BLOCK;
23
24#[doc(hidden)]
25pub struct GeneralPurposeEstimate {
26 num_chunks: usize,
28 decoded_len_estimate: usize,
29}
30
31impl GeneralPurposeEstimate {
32 pub(crate) fn new(encoded_len: usize) -> Self {
33 Self {
35 num_chunks: encoded_len / INPUT_CHUNK_LEN
36 + (encoded_len % INPUT_CHUNK_LEN > 0) as usize,
37 decoded_len_estimate: (encoded_len / 4 + (encoded_len % 4 > 0) as usize) * 3,
38 }
39 }
40}
41
42impl DecodeEstimate for GeneralPurposeEstimate {
43 fn decoded_len_estimate(&self) -> usize {
44 self.decoded_len_estimate
45 }
46}
47
/// Decodes base64 `input` into `output` with the general-purpose algorithm.
///
/// `estimate` must be the [`GeneralPurposeEstimate`] for `input.len()`. Its `num_chunks` seeds
/// the chunk countdown so it is not recomputed here. Decoding runs in four stages:
///
/// 1. an unrolled loop that decodes 4 chunks (32 symbols) per iteration,
/// 2. a loop that decodes one 8-symbol chunk per iteration,
/// 3. a slower loop using `decode_chunk_precise` for every remaining chunk except the last,
/// 4. `decode_suffix` for the final (possibly partial or padded) chunk.
///
/// Stages 1 and 2 write 8 output bytes per chunk, of which only 6 are data.
/// `trailing_bytes_to_skip` holds back enough input that the bytes decoded afterwards overwrite
/// those 2 extra bytes.
///
/// # Errors
///
/// * `DecodeError::InvalidLength` if `input.len() % 8` is 1 or 5. If the last byte is neither
///   padding nor a valid symbol, `DecodeError::InvalidByte` is returned for it instead.
/// * `DecodeError::InvalidByte`, carrying the absolute input index, for the first invalid
///   symbol met by the chunk decoders.
/// * Whatever `decode_suffix` returns for the final chunk.
///
/// # Panics
///
/// Panics if `output` is too short for the (bounds-checked) slicing below.
/// NOTE(review): presumably callers size `output` from `estimate`'s decoded length estimate —
/// confirm at the call sites.
// NOTE(review): plain `#[inline]` (not `always`) looks like a benchmark-driven choice; measure
// before changing it.
#[inline]
pub(crate) fn decode_helper(
    input: &[u8],
    estimate: GeneralPurposeEstimate,
    output: &mut [u8],
    decode_table: &[u8; 256],
    decode_allow_trailing_bits: bool,
    padding_mode: DecodePaddingMode,
) -> Result<DecodeMetadata, DecodeError> {
    let remainder_len = input.len() % INPUT_CHUNK_LEN;

    // How many trailing input bytes the fast loops must leave alone. Enough must remain that
    // the bytes decoded from them overwrite the 2-byte spill of the last fast `u64` write.
    let trailing_bytes_to_skip = match remainder_len {
        // A multiple of 8: hold back the whole last chunk. It may contain padding, which only
        // `decode_suffix` (it receives `padding_mode`) handles; `decode_chunk` has no padding
        // logic.
        0 => INPUT_CHUNK_LEN,
        // 1 or 5 leftover symbols can never be a valid encoding, because a lone symbol carries
        // only 6 bits.
        1 | 5 => {
            // If the final byte is neither padding nor a valid symbol (e.g. stray trailing
            // whitespace), report it as InvalidByte, which is more precise than InvalidLength.
            if let Some(b) = input.last() {
                if *b != PAD_BYTE && decode_table[*b as usize] == INVALID_VALUE {
                    return Err(DecodeError::InvalidByte(input.len() - 1, *b));
                }
            }

            return Err(DecodeError::InvalidLength);
        }
        // 2 leftover symbols decode to a single byte, which is too few to cover a 2-byte spill.
        // Hold back the preceding full chunk as well.
        2 => INPUT_CHUNK_LEN + 2,
        // 3 symbols may really be 2 symbols + 1 pad (one output byte); treat them like case 2.
        // If they are erroneous, this also ensures an error rather than an out-of-bounds write.
        3 => INPUT_CHUNK_LEN + 3,
        // 4 symbols may be 2 symbols + 2 pads (one output byte); treat them like case 2.
        4 => INPUT_CHUNK_LEN + 4,
        // 6 or 7 leftover symbols, when valid, decode to at least 2 bytes, enough to cover the
        // spill.
        _ => remainder_len,
    };

    // 8-symbol chunks still to decode, with a trailing partial chunk counted as one.
    let mut remaining_chunks = estimate.num_chunks;

    let mut input_index = 0;
    let mut output_index = 0;

    {
        // Input prefix that the two fast (spilling) loops may consume.
        let length_of_fast_decode_chunks = input.len().saturating_sub(trailing_bytes_to_skip);

        // Stage 1: manually unrolled, 4 chunks (32 symbols) per iteration.
        if let Some(max_start_index) = length_of_fast_decode_chunks.checked_sub(INPUT_BLOCK_LEN) {
            while input_index <= max_start_index {
                let input_slice = &input[input_index..(input_index + INPUT_BLOCK_LEN)];
                // 26-byte window: 24 data bytes plus the spill of the 4th chunk's u64 write.
                let output_slice = &mut output[output_index..(output_index + DECODED_BLOCK_LEN)];

                // Chunk k reads input_slice[8k..] and writes output_slice[6k..6k + 8]. Each
                // write's 2 spill bytes are overwritten by the next chunk's write.
                decode_chunk(
                    &input_slice[0..],
                    input_index,
                    decode_table,
                    &mut output_slice[0..],
                )?;
                decode_chunk(
                    &input_slice[8..],
                    input_index + 8,
                    decode_table,
                    &mut output_slice[6..],
                )?;
                decode_chunk(
                    &input_slice[16..],
                    input_index + 16,
                    decode_table,
                    &mut output_slice[12..],
                )?;
                decode_chunk(
                    &input_slice[24..],
                    input_index + 24,
                    decode_table,
                    &mut output_slice[18..],
                )?;

                input_index += INPUT_BLOCK_LEN;
                // Advance past the 24 data bytes only. The 2 spill bytes will be overwritten.
                output_index += DECODED_BLOCK_LEN - DECODED_CHUNK_SUFFIX;
                remaining_chunks -= CHUNKS_PER_FAST_LOOP_BLOCK;
            }
        }

        // Stage 2: one chunk per iteration for whatever stage 1 left over. The output window
        // again includes the 2-byte spill.
        // NOTE(review): stage 1 uses `<=`, but this loop uses `<`, so a chunk starting exactly
        // at `max_start_index` falls through to stage 3. That looks conservative rather than
        // required — confirm before changing.
        if let Some(max_start_index) = length_of_fast_decode_chunks.checked_sub(INPUT_CHUNK_LEN) {
            while input_index < max_start_index {
                decode_chunk(
                    &input[input_index..(input_index + INPUT_CHUNK_LEN)],
                    input_index,
                    decode_table,
                    &mut output
                        [output_index..(output_index + DECODED_CHUNK_LEN + DECODED_CHUNK_SUFFIX)],
                )?;

                output_index += DECODED_CHUNK_LEN;
                input_index += INPUT_CHUNK_LEN;
                remaining_chunks -= 1;
            }
        }
    }

    // Stage 3: decode every remaining chunk except the last with `decode_chunk_precise`, which
    // writes exactly DECODED_CHUNK_LEN bytes and so never spills. `1..remaining_chunks` runs
    // `remaining_chunks - 1` times (zero times when it is 0 or 1), leaving the final chunk,
    // which may be partial or padded, for `decode_suffix`.
    for _ in 1..remaining_chunks {
        decode_chunk_precise(
            &input[input_index..],
            input_index,
            decode_table,
            &mut output[output_index..(output_index + DECODED_CHUNK_LEN)],
        )?;

        input_index += INPUT_CHUNK_LEN;
        output_index += DECODED_CHUNK_LEN;
    }

    // Exactly one chunk should remain: at least 2 bytes (unless the input was empty) and at
    // most 8.
    debug_assert!(input.len() - input_index > 1 || input.is_empty());
    debug_assert!(input.len() - input_index <= 8);

    // Stage 4: the final chunk goes to `decode_suffix`, and its result is returned as-is.
    // Judging by the flags passed, it presumably also applies the padding and trailing-bit
    // policies.
    super::decode_suffix::decode_suffix(
        input,
        input_index,
        output,
        output_index,
        decode_table,
        decode_allow_trailing_bits,
        padding_mode,
    )
}
200
201#[inline(always)]
212fn decode_chunk(
213 input: &[u8],
214 index_at_start_of_input: usize,
215 decode_table: &[u8; 256],
216 output: &mut [u8],
217) -> Result<(), DecodeError> {
218 let morsel = decode_table[input[0] as usize];
219 if morsel == INVALID_VALUE {
220 return Err(DecodeError::InvalidByte(index_at_start_of_input, input[0]));
221 }
222 let mut accum = (morsel as u64) << 58;
223
224 let morsel = decode_table[input[1] as usize];
225 if morsel == INVALID_VALUE {
226 return Err(DecodeError::InvalidByte(
227 index_at_start_of_input + 1,
228 input[1],
229 ));
230 }
231 accum |= (morsel as u64) << 52;
232
233 let morsel = decode_table[input[2] as usize];
234 if morsel == INVALID_VALUE {
235 return Err(DecodeError::InvalidByte(
236 index_at_start_of_input + 2,
237 input[2],
238 ));
239 }
240 accum |= (morsel as u64) << 46;
241
242 let morsel = decode_table[input[3] as usize];
243 if morsel == INVALID_VALUE {
244 return Err(DecodeError::InvalidByte(
245 index_at_start_of_input + 3,
246 input[3],
247 ));
248 }
249 accum |= (morsel as u64) << 40;
250
251 let morsel = decode_table[input[4] as usize];
252 if morsel == INVALID_VALUE {
253 return Err(DecodeError::InvalidByte(
254 index_at_start_of_input + 4,
255 input[4],
256 ));
257 }
258 accum |= (morsel as u64) << 34;
259
260 let morsel = decode_table[input[5] as usize];
261 if morsel == INVALID_VALUE {
262 return Err(DecodeError::InvalidByte(
263 index_at_start_of_input + 5,
264 input[5],
265 ));
266 }
267 accum |= (morsel as u64) << 28;
268
269 let morsel = decode_table[input[6] as usize];
270 if morsel == INVALID_VALUE {
271 return Err(DecodeError::InvalidByte(
272 index_at_start_of_input + 6,
273 input[6],
274 ));
275 }
276 accum |= (morsel as u64) << 22;
277
278 let morsel = decode_table[input[7] as usize];
279 if morsel == INVALID_VALUE {
280 return Err(DecodeError::InvalidByte(
281 index_at_start_of_input + 7,
282 input[7],
283 ));
284 }
285 accum |= (morsel as u64) << 16;
286
287 write_u64(output, accum);
288
289 Ok(())
290}
291
292#[inline]
295fn decode_chunk_precise(
296 input: &[u8],
297 index_at_start_of_input: usize,
298 decode_table: &[u8; 256],
299 output: &mut [u8],
300) -> Result<(), DecodeError> {
301 let mut tmp_buf = [0_u8; 8];
302
303 decode_chunk(
304 input,
305 index_at_start_of_input,
306 decode_table,
307 &mut tmp_buf[..],
308 )?;
309
310 output[0..6].copy_from_slice(&tmp_buf[0..6]);
311
312 Ok(())
313}
314
/// Stores `value` in big-endian byte order into the first 8 bytes of `output`.
///
/// # Panics
///
/// Panics if `output` is shorter than 8 bytes.
#[inline]
fn write_u64(output: &mut [u8], value: u64) {
    let bytes = value.to_be_bytes();
    output[..bytes.len()].copy_from_slice(&bytes);
}
319
/// Unit tests for the chunk decoders, the length estimate, and `decode_helper`'s length checks
/// and staged decoding.
#[cfg(test)]
mod tests {
    use super::*;

    use crate::engine::general_purpose::STANDARD;

    #[test]
    fn decode_chunk_precise_writes_only_6_bytes() {
        let input = b"Zm9vYmFy";
        let mut output = [0_u8, 1, 2, 3, 4, 5, 6, 7];

        decode_chunk_precise(&input[..], 0, &STANDARD.decode_table, &mut output).unwrap();
        assert_eq!([b'f', b'o', b'o', b'b', b'a', b'r', 6, 7], output);
    }

    #[test]
    fn decode_chunk_writes_8_bytes() {
        let input = b"Zm9vYmFy";
        let mut output = [0_u8, 1, 2, 3, 4, 5, 6, 7];

        decode_chunk(&input[..], 0, &STANDARD.decode_table, &mut output).unwrap();
        assert_eq!([b'f', b'o', b'o', b'b', b'a', b'r', 0, 0], output);
    }

    #[test]
    fn decode_chunk_reports_absolute_index_of_first_invalid_byte() {
        let mut output = [0_u8; 8];

        // Both '*' and '!' are outside the standard alphabet. The first one is reported, and
        // its position is offset by `index_at_start_of_input`.
        assert_eq!(
            Err(DecodeError::InvalidByte(13, b'*')),
            decode_chunk(b"Zm9v*mF!", 9, &STANDARD.decode_table, &mut output)
        );
    }

    #[test]
    fn decode_helper_rejects_lengths_of_1_or_5_mod_8() {
        // The last byte of each input is valid or padding, so the length error is reported.
        let inputs: [&[u8]; 5] = [b"Z", b"Zm9vY", b"Zm9v=", b"Zm9vYmFyZ", b"Zm9vYmFyZm9vY"];
        for input in inputs {
            let mut output = [0_u8; 16];
            let result = decode_helper(
                input,
                GeneralPurposeEstimate::new(input.len()),
                &mut output,
                &STANDARD.decode_table,
                false,
                DecodePaddingMode::Indifferent,
            );
            assert_eq!(Some(DecodeError::InvalidLength), result.err());
        }
    }

    #[test]
    fn decode_helper_reports_invalid_final_byte_for_bad_length() {
        let input = b"Zm9v\n";
        let mut output = [0_u8; 8];

        let result = decode_helper(
            &input[..],
            GeneralPurposeEstimate::new(input.len()),
            &mut output,
            &STANDARD.decode_table,
            false,
            DecodePaddingMode::Indifferent,
        );
        assert_eq!(Some(DecodeError::InvalidByte(4, b'\n')), result.err());
    }

    #[test]
    fn decode_helper_decodes_input_spanning_all_stages() {
        // 56 symbols (7 chunks): stage 1 decodes chunks 0-3, stage 2 chunk 4, stage 3 chunk 5,
        // and decode_suffix chunk 6.
        let input = b"Zm9vYmFy".repeat(7);
        let mut output = [0_u8; 42];

        let result = decode_helper(
            &input,
            GeneralPurposeEstimate::new(input.len()),
            &mut output,
            &STANDARD.decode_table,
            false,
            DecodePaddingMode::Indifferent,
        );
        assert!(result.is_ok());
        assert_eq!(b"foobar".repeat(7), &output[..]);
    }

    #[test]
    fn estimate_short_lengths() {
        for (range, (num_chunks, decoded_len_estimate)) in [
            (0..=0, (0, 0)),
            (1..=4, (1, 3)),
            (5..=8, (1, 6)),
            (9..=12, (2, 9)),
            (13..=16, (2, 12)),
            (17..=20, (3, 15)),
        ] {
            for encoded_len in range {
                let estimate = GeneralPurposeEstimate::new(encoded_len);
                assert_eq!(num_chunks, estimate.num_chunks);
                assert_eq!(decoded_len_estimate, estimate.decoded_len_estimate);
            }
        }
    }

    #[test]
    fn estimate_via_u128_inflation() {
        // Compare against the naive ceiling formulas computed in u128, where they cannot
        // overflow, including lengths right up to usize::MAX.
        (0..1000)
            .chain(usize::MAX - 1000..=usize::MAX)
            .for_each(|encoded_len| {
                let len_128 = encoded_len as u128;

                let estimate = GeneralPurposeEstimate::new(encoded_len);
                assert_eq!(
                    ((len_128 + (INPUT_CHUNK_LEN - 1) as u128) / (INPUT_CHUNK_LEN as u128))
                        as usize,
                    estimate.num_chunks
                );
                assert_eq!(
                    ((len_128 + 3) / 4 * 3) as usize,
                    estimate.decoded_len_estimate
                );
            })
    }
}