yagit

Yet another static site generator for Git 🙀️

Commit
652d3e236e94bca8103d9b7c5de81ab7093e8bc8
Parent
590fa2cc822167c0f918770812cddc5471ffb7f6
Author
Pablo <pablo-pie@riseup.net>
Date

Cleaned the SIMD escaping code

Diffstats

1 files changed, 96 insertions, 90 deletions

Status Name Changes Insertions Deletions
Modified src/escape.rs 2 files changed 96 90
diff --git a/src/escape.rs b/src/escape.rs
@@ -10,14 +10,15 @@ static ESCAPE_TABLE: [Option<&str>; 256] = create_html_escape_table();
 /// A wrapper for HTML-escaped strings
 pub struct Escaped<'a>(pub &'a str);
 
-// Stolen from pulldown-cmark-escape
+// stolen from pulldown-cmark-escape
 impl Display for Escaped<'_> {
   #[cfg(target_arch = "x86_64")]
   fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
-    // The SIMD accelerated code uses the PSHUFB instruction, which is part
-    // of the SSSE3 instruction set. Further, we can only use this code if
-    // the buffer is at least one VECTOR_SIZE in length to prevent reading
-    // out of bounds
+    // the SIMD accelerated code uses the PSHUFB instruction, which is part
+    // of the SSSE3 instruction set
+    //
+    // further, we can only use this code if the buffer is at least one
+    // VECTOR_SIZE in length to prevent reading out of bounds
     if is_x86_feature_detected!("ssse3") && self.0.len() >= simd::VECTOR_SIZE {
       simd::fmt_escaped_html(self.0, f)
     } else {
@@ -31,7 +32,7 @@ impl Display for Escaped<'_> {
   }
 }
 
-// Stolen from pulldown-cmark-escape
+// stolen from pulldown-cmark-escape
 fn fmt_escaped_html_scalar(
   s: &str,
   f: &mut fmt::Formatter<'_>
@@ -89,114 +90,119 @@ mod simd {
   pub const VECTOR_SIZE: usize = mem::size_of::<__m128i>();
 
   pub fn fmt_escaped_html(s: &str, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+    // the strategy here is to walk s in chunks of VECTOR_SIZE (16) bytes at
+    // a time:
+    //
+    // 1. for each chunk, we compute a a bitmask indicating whether the
+    //    corresponding byte is a HTML special byte
+    // 2. for each bit set in this mask, we print the escaped character
+    //    accordingly, as well as the surrounding characters that don't need
+    //    escaping
+    //
+    // when the number of HTML special bytes in the buffer is relatively low,
+    // this allows us to quickly go through the buffer without a lookup and
+    // for every single byte
     let bytes = s.as_bytes();
+    debug_assert!(bytes.len() >= VECTOR_SIZE);
+
     let mut mark = 0;
+    let mut offset = 0;
 
     unsafe {
-      foreach_special_simd(bytes, 0, |i| {
-        let c = *bytes.get_unchecked(i) as usize;
-        let entry = super::ESCAPE_TABLE[c];
-        f.write_str(s.get_unchecked(mark..i))?;
-        mark = i + 1; // all escaped characters are ASCII
-        if let Some(replacement) = entry {
-          f.write_str(replacement)
-        } else {
-          f.write_str(s.get_unchecked(i..mark))
-        } 
-      })?;
-      f.write_str(s.get_unchecked(mark..))
-    }
-  }
+      let upperbound = bytes.len() - VECTOR_SIZE;
+      while offset < upperbound {
+        let mut mask = compute_mask(bytes, offset);
+
+        while mask != 0 {
+          let first_special = mask.trailing_zeros();
+          let i = offset + first_special as usize;
+          let c = *bytes.get_unchecked(i) as usize;
+
+          // here we know c = s[i] is a character that should be escaped,
+          // so it is safe to unwrap ESCAPE_TABLE[c]
+          let replacement = super::ESCAPE_TABLE[c].unwrap();
+          f.write_str(s.get_unchecked(mark..i))?;
+          f.write_str(replacement)?;
+
+          mark = i + 1; // all escaped characters are ASCII
+          mask ^= mask & -mask;
+        }
+
+        offset += VECTOR_SIZE;
+      }
 
-  unsafe fn foreach_special_simd<F: FnMut(usize) -> fmt::Result>(
-    bytes: &[u8],
-    mut offset: usize,
-    mut callback: F,
-  ) -> fmt::Result {
-    // The strategy here is to walk the byte buffer in chunks of
-    // VECTOR_SIZE (16) bytes at a time starting at the given offset.
-    // For each chunk, we compute a a bitmask indicating whether the
-    // corresponding byte is a HTML special byte. We then iterate over all
-    // the 1 bits in this mask and call the callback function with the
-    // corresponding index in the buffer.
-    //
-    // When the number of HTML special bytes in the buffer is relatively low,
-    // this allows us to quickly go through the buffer without a lookup and
-    // for every single byte.
+      // ======================================================================
+      // final iteration: we align the read with the end of the slice
+      // and shift off the bytes at start we have already scanned
+      let mut mask = compute_mask(bytes, upperbound);
+      mask >>= offset - upperbound;
 
-    debug_assert!(bytes.len() >= VECTOR_SIZE);
-    let upperbound = bytes.len() - VECTOR_SIZE;
-    while offset < upperbound {
-      let mut mask = compute_mask(bytes, offset);
       while mask != 0 {
-        let ix = mask.trailing_zeros();
-        callback(offset + ix as usize)?;
+        let first_special = mask.trailing_zeros();
+        let i = offset + first_special as usize;
+        let c = *bytes.get_unchecked(i) as usize;
+
+        // here we know c = s[i] is a character that should be escaped,
+        // so it is safe to unwrap ESCAPE_TABLE[c]
+        let replacement = super::ESCAPE_TABLE[c].unwrap();
+        f.write_str(s.get_unchecked(mark..i))?;
+        f.write_str(replacement)?;
+
+        mark = i + 1; // all escaped characters are ASCII
         mask ^= mask & -mask;
       }
-      offset += VECTOR_SIZE;
-    }
 
-    // Final iteration. We align the read with the end of the slice and
-    // shift off the bytes at start we have already scanned.
-    let mut mask = compute_mask(bytes, upperbound);
-    mask >>= offset - upperbound;
-    while mask != 0 {
-      let ix = mask.trailing_zeros();
-      callback(offset + ix as usize)?;
-      mask ^= mask & -mask;
+      f.write_str(s.get_unchecked(mark..))
     }
-    Ok(())
   }
 
+
   #[target_feature(enable = "ssse3")]
   /// Computes a byte mask at given offset in the byte buffer. Its first 16
   /// (least significant) bits correspond to whether there is an HTML special
-  /// byte (&, <, ", >) at the 16 bytes `bytes[offset..]`. For example, the
-  /// mask `(1 << 3)` states that there is an HTML byte at `offset + 3`. It is
-  /// only safe to call this function when `bytes.len() >= offset +
+  /// byte at the first VECTOR_SIZE bytes `bytes[offset..]`.
+  ///
+  /// It is only safe to call this function when `bytes.len() >= offset +
   /// VECTOR_SIZE`.
   unsafe fn compute_mask(bytes: &[u8], offset: usize) -> i32 {
     debug_assert!(bytes.len() >= offset + VECTOR_SIZE);
 
-    let table = create_lookup();
-    let lookup = _mm_loadu_si128(table.as_ptr() as *const __m128i);
-    let raw_ptr = bytes.as_ptr().add(offset) as *const __m128i;
+    const LOOKUP_TABLE: [u8; VECTOR_SIZE] = create_lookup();
+    const fn create_lookup() -> [u8; VECTOR_SIZE] {
+      let mut table = [0; VECTOR_SIZE];
+      table[(b'<'  & 0x0f) as usize] = b'<';
+      table[(b'>'  & 0x0f) as usize] = b'>';
+      table[(b'&'  & 0x0f) as usize] = b'&';
+      table[(b'"'  & 0x0f) as usize] = b'"';
+      table[(b'\'' & 0x0f) as usize] = b'\'';
+      table[0]                       = 0b01111111;
+      table
+    }
 
-    // Load the vector from memory.
+    let lookup_table = _mm_loadu_si128(
+      LOOKUP_TABLE.as_ptr() as *const __m128i
+    );
+    let raw_ptr = bytes.as_ptr().add(offset) as *const __m128i;
     let vector = _mm_loadu_si128(raw_ptr);
 
-    // We take the least significant 4 bits of every byte and use them as
-    // indices to map into the lookup vector.
+    // mask the vector using the lookup table:
     //
-    // Note that shuffle maps bytes with their most significant bit set to
-    // lookup[0]. Bytes that share their lower nibble with an HTML special
-    // byte get mapped to that corresponding special byte. Note that all HTML
-    // special bytes have distinct lower nibbles. Other bytes either get
-    // mapped to 0 or 127.
-    let expected = _mm_shuffle_epi8(lookup, vector);
-
-    // We compare the original vector to the mapped output. Bytes that shared
-    // a lower nibble with an HTML special byte match *only* if they are that
-    // special byte. Bytes that have either a 0 lower nibble or their most
-    // significant bit set were mapped to 127 and will hence never match. All
-    // other bytes have non-zero lower nibbles but were mapped to 0 and will
-    // therefore also not match.
-    let matches = _mm_cmpeq_epi8(expected, vector);
-
-    // Translate matches to a bitmask, where every 1 corresponds to a HTML
-    // special character and a 0 is a non-HTML byte.
-    _mm_movemask_epi8(matches)
-  }
+    // 1. bytes whose lower nibbles are special HTML characters get mapped to
+    //    their lower nibbles
+    // 2. bytes whose lower nibbles are nonzero and *not* special HTML
+    //    characters get mapped to 0
+    // 3. bytes whose lower nibbles are 0 get mapped to 0b01111111
+    let masked = _mm_shuffle_epi8(lookup_table, vector);
+
+    // compare the original vector to the masked one:
+    //
+    // 1. bytes that shared a lower nibble with an HTML special byte match
+    //    *only* if they are that special byte
+    // 2. all other bytes will never match
+    let matches = _mm_cmpeq_epi8(masked, vector);
 
-  /// Creates the lookup table for use in `compute_mask`.
-  const fn create_lookup() -> [u8; 16] {
-    let mut table = [0; 16];
-    table[(b'<' & 0x0f)  as usize] = b'<';
-    table[(b'>' & 0x0f)  as usize] = b'>';
-    table[(b'&' & 0x0f)  as usize] = b'&';
-    table[(b'"' & 0x0f)  as usize] = b'"';
-    table[(b'\'' & 0x0f) as usize] = b'\'';
-    table[0] = 0b0111_1111;
-    table
+    // translate matches to a bitmask: every 1 corresponds to a HTML
+    // special character and a 0 is a non-HTML byte
+    _mm_movemask_epi8(matches)
   }
 }