- Commit
- 652d3e236e94bca8103d9b7c5de81ab7093e8bc8
- Parent
- 590fa2cc822167c0f918770812cddc5471ffb7f6
- Author
- Pablo <pablo-pie@riseup.net>
- Date
Cleaned the SIMD escaping code
Yet another static site generator for Git 🙀️
Cleaned the SIMD escaping code
1 file changed, 96 insertions, 90 deletions
Status | Name | Changes | Insertions | Deletions |
Modified | src/escape.rs | 2 files changed | 96 | 90 |
diff --git a/src/escape.rs b/src/escape.rs @@ -10,14 +10,15 @@ static ESCAPE_TABLE: [Option<&str>; 256] = create_html_escape_table(); /// A wrapper for HTML-escaped strings pub struct Escaped<'a>(pub &'a str); -// Stolen from pulldown-cmark-escape +// stolen from pulldown-cmark-escape impl Display for Escaped<'_> { #[cfg(target_arch = "x86_64")] fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - // The SIMD accelerated code uses the PSHUFB instruction, which is part - // of the SSSE3 instruction set. Further, we can only use this code if - // the buffer is at least one VECTOR_SIZE in length to prevent reading - // out of bounds + // the SIMD accelerated code uses the PSHUFB instruction, which is part + // of the SSSE3 instruction set + // + // further, we can only use this code if the buffer is at least one + // VECTOR_SIZE in length to prevent reading out of bounds if is_x86_feature_detected!("ssse3") && self.0.len() >= simd::VECTOR_SIZE { simd::fmt_escaped_html(self.0, f) } else { @@ -31,7 +32,7 @@ impl Display for Escaped<'_> { } } -// Stolen from pulldown-cmark-escape +// stolen from pulldown-cmark-escape fn fmt_escaped_html_scalar( s: &str, f: &mut fmt::Formatter<'_> @@ -89,114 +90,119 @@ mod simd { pub const VECTOR_SIZE: usize = mem::size_of::<__m128i>(); pub fn fmt_escaped_html(s: &str, f: &mut fmt::Formatter<'_>) -> fmt::Result { + // the strategy here is to walk s in chunks of VECTOR_SIZE (16) bytes at + // a time: + // + // 1. for each chunk, we compute a a bitmask indicating whether the + // corresponding byte is a HTML special byte + // 2. 
for each bit set in this mask, we print the escaped character + // accordingly, as well as the surrounding characters that don't need + // escaping + // + // when the number of HTML special bytes in the buffer is relatively low, + // this allows us to quickly go through the buffer without a lookup and + // for every single byte let bytes = s.as_bytes(); + debug_assert!(bytes.len() >= VECTOR_SIZE); + let mut mark = 0; + let mut offset = 0; unsafe { - foreach_special_simd(bytes, 0, |i| { - let c = *bytes.get_unchecked(i) as usize; - let entry = super::ESCAPE_TABLE[c]; - f.write_str(s.get_unchecked(mark..i))?; - mark = i + 1; // all escaped characters are ASCII - if let Some(replacement) = entry { - f.write_str(replacement) - } else { - f.write_str(s.get_unchecked(i..mark)) - } - })?; - f.write_str(s.get_unchecked(mark..)) - } - } + let upperbound = bytes.len() - VECTOR_SIZE; + while offset < upperbound { + let mut mask = compute_mask(bytes, offset); + + while mask != 0 { + let first_special = mask.trailing_zeros(); + let i = offset + first_special as usize; + let c = *bytes.get_unchecked(i) as usize; + + // here we know c = s[i] is a character that should be escaped, + // so it is safe to unwrap ESCAPE_TABLE[c] + let replacement = super::ESCAPE_TABLE[c].unwrap(); + f.write_str(s.get_unchecked(mark..i))?; + f.write_str(replacement)?; + + mark = i + 1; // all escaped characters are ASCII + mask ^= mask & -mask; + } + + offset += VECTOR_SIZE; + } - unsafe fn foreach_special_simd<F: FnMut(usize) -> fmt::Result>( - bytes: &[u8], - mut offset: usize, - mut callback: F, - ) -> fmt::Result { - // The strategy here is to walk the byte buffer in chunks of - // VECTOR_SIZE (16) bytes at a time starting at the given offset. - // For each chunk, we compute a a bitmask indicating whether the - // corresponding byte is a HTML special byte. We then iterate over all - // the 1 bits in this mask and call the callback function with the - // corresponding index in the buffer. 
- // - // When the number of HTML special bytes in the buffer is relatively low, - // this allows us to quickly go through the buffer without a lookup and - // for every single byte. + // ====================================================================== + // final iteration: we align the read with the end of the slice + // and shift off the bytes at start we have already scanned + let mut mask = compute_mask(bytes, upperbound); + mask >>= offset - upperbound; - debug_assert!(bytes.len() >= VECTOR_SIZE); - let upperbound = bytes.len() - VECTOR_SIZE; - while offset < upperbound { - let mut mask = compute_mask(bytes, offset); while mask != 0 { - let ix = mask.trailing_zeros(); - callback(offset + ix as usize)?; + let first_special = mask.trailing_zeros(); + let i = offset + first_special as usize; + let c = *bytes.get_unchecked(i) as usize; + + // here we know c = s[i] is a character that should be escaped, + // so it is safe to unwrap ESCAPE_TABLE[c] + let replacement = super::ESCAPE_TABLE[c].unwrap(); + f.write_str(s.get_unchecked(mark..i))?; + f.write_str(replacement)?; + + mark = i + 1; // all escaped characters are ASCII mask ^= mask & -mask; } - offset += VECTOR_SIZE; - } - // Final iteration. We align the read with the end of the slice and - // shift off the bytes at start we have already scanned. - let mut mask = compute_mask(bytes, upperbound); - mask >>= offset - upperbound; - while mask != 0 { - let ix = mask.trailing_zeros(); - callback(offset + ix as usize)?; - mask ^= mask & -mask; + f.write_str(s.get_unchecked(mark..)) } - Ok(()) } + #[target_feature(enable = "ssse3")] /// Computes a byte mask at given offset in the byte buffer. Its first 16 /// (least significant) bits correspond to whether there is an HTML special - /// byte (&, <, ", >) at the 16 bytes `bytes[offset..]`. For example, the - /// mask `(1 << 3)` states that there is an HTML byte at `offset + 3`. 
It is - /// only safe to call this function when `bytes.len() >= offset + + /// byte at the first VECTOR_SIZE bytes `bytes[offset..]`. + /// + /// It is only safe to call this function when `bytes.len() >= offset + /// VECTOR_SIZE`. unsafe fn compute_mask(bytes: &[u8], offset: usize) -> i32 { debug_assert!(bytes.len() >= offset + VECTOR_SIZE); - let table = create_lookup(); - let lookup = _mm_loadu_si128(table.as_ptr() as *const __m128i); - let raw_ptr = bytes.as_ptr().add(offset) as *const __m128i; + const LOOKUP_TABLE: [u8; VECTOR_SIZE] = create_lookup(); + const fn create_lookup() -> [u8; VECTOR_SIZE] { + let mut table = [0; VECTOR_SIZE]; + table[(b'<' & 0x0f) as usize] = b'<'; + table[(b'>' & 0x0f) as usize] = b'>'; + table[(b'&' & 0x0f) as usize] = b'&'; + table[(b'"' & 0x0f) as usize] = b'"'; + table[(b'\'' & 0x0f) as usize] = b'\''; + table[0] = 0b01111111; + table + } - // Load the vector from memory. + let lookup_table = _mm_loadu_si128( + LOOKUP_TABLE.as_ptr() as *const __m128i + ); + let raw_ptr = bytes.as_ptr().add(offset) as *const __m128i; let vector = _mm_loadu_si128(raw_ptr); - // We take the least significant 4 bits of every byte and use them as - // indices to map into the lookup vector. + // mask the vector using the lookup table: // - // Note that shuffle maps bytes with their most significant bit set to - // lookup[0]. Bytes that share their lower nibble with an HTML special - // byte get mapped to that corresponding special byte. Note that all HTML - // special bytes have distinct lower nibbles. Other bytes either get - // mapped to 0 or 127. - let expected = _mm_shuffle_epi8(lookup, vector); - - // We compare the original vector to the mapped output. Bytes that shared - // a lower nibble with an HTML special byte match *only* if they are that - // special byte. Bytes that have either a 0 lower nibble or their most - // significant bit set were mapped to 127 and will hence never match. 
All - // other bytes have non-zero lower nibbles but were mapped to 0 and will - // therefore also not match. - let matches = _mm_cmpeq_epi8(expected, vector); - - // Translate matches to a bitmask, where every 1 corresponds to a HTML - // special character and a 0 is a non-HTML byte. - _mm_movemask_epi8(matches) - } + // 1. bytes whose lower nibbles are special HTML characters get mapped to + // their lower nibbles + // 2. bytes whose lower nibbles are nonzero and *not* special HTML + // characters get mapped to 0 + // 3. bytes whose lower nibbles are 0 get mapped to 0b01111111 + let masked = _mm_shuffle_epi8(lookup_table, vector); + + // compare the original vector to the masked one: + // + // 1. bytes that shared a lower nibble with an HTML special byte match + // *only* if they are that special byte + // 2. all other bytes will never match + let matches = _mm_cmpeq_epi8(masked, vector); - /// Creates the lookup table for use in `compute_mask`. - const fn create_lookup() -> [u8; 16] { - let mut table = [0; 16]; - table[(b'<' & 0x0f) as usize] = b'<'; - table[(b'>' & 0x0f) as usize] = b'>'; - table[(b'&' & 0x0f) as usize] = b'&'; - table[(b'"' & 0x0f) as usize] = b'"'; - table[(b'\'' & 0x0f) as usize] = b'\''; - table[0] = 0b0111_1111; - table + // translate matches to a bitmask: every 1 corresponds to a HTML + // special character and a 0 is a non-HTML byte + _mm_movemask_epi8(matches) } }