yagit

Yet another static site generator for Git 🙀️

Commit
a8eee6f04ad9714334bf6ec098c3e307000cc8be
Parent
5b34e3ec5ba917d8c8201a38f8e78e99242194cf
Author
Pablo <pablo-pie@riseup.net>
Date

Made escape::simd::fmt_escaped_html safe

Diffstats

1 file changed, 24 insertions, 34 deletions

Status Name Changes Insertions Deletions
Modified src/escape.rs 2 files changed 24 34
diff --git a/src/escape.rs b/src/escape.rs
@@ -5,7 +5,16 @@
 
 use std::fmt::{self, Display};
 
-static ESCAPE_TABLE: [Option<&str>; 256] = create_html_escape_table();
+const ESCAPE_TABLE: [Option<&str>; 256] = create_escape_table();
+const fn create_escape_table() -> [Option<&'static str>; 256] {
+  let mut table = [None; 256];
+  table[b'<'  as usize] = Some("&lt;");
+  table[b'>'  as usize] = Some("&gt;");
+  table[b'&'  as usize] = Some("&amp;");
+  table[b'"'  as usize] = Some("&quot;");
+  table[b'\'' as usize] = Some("&apos;");
+  table
+}
 
 /// A wrapper for HTML-escaped strings
 pub struct Escaped<'a>(pub &'a str);
@@ -16,10 +25,7 @@ impl Display for Escaped<'_> {
   fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
     // the SIMD accelerated code uses the PSHUFB instruction, which is part
     // of the SSSE3 instruction set
-    //
-    // further, we can only use this code if the buffer is at least one
-    // VECTOR_SIZE in length to prevent reading out of bounds
-    if is_x86_feature_detected!("ssse3") && self.0.len() >= simd::VECTOR_SIZE {
+    if is_x86_feature_detected!("ssse3") {
       simd::fmt_escaped_html(self.0, f)
     } else {
       fmt_escaped_html_scalar(self.0, f)
@@ -62,33 +68,14 @@ fn fmt_escaped_html_scalar(
   f.write_str(&s[mark..])
 }
 
-const fn create_html_escape_table() -> [Option<&'static str>; 256] {
-  let mut table = [None; 256];
-  table[b'<'  as usize] = Some("&lt;");
-  table[b'>'  as usize] = Some("&gt;");
-  table[b'&'  as usize] = Some("&amp;");
-  table[b'"'  as usize] = Some("&quot;");
-  table[b'\'' as usize] = Some("&apos;");
-  table
-}
-
 // stolen from pulldown-cmark-escape
 #[cfg(target_arch = "x86_64")]
 mod simd {
-  use std::{
-    arch::x86_64::{
-      __m128i,
-      _mm_loadu_si128,
-      _mm_shuffle_epi8,
-      _mm_cmpeq_epi8,
-      _mm_movemask_epi8,
-    },
-    mem,
-    fmt,
-  };
-
-  pub const VECTOR_SIZE: usize = mem::size_of::<__m128i>();
+  use std::{arch::x86_64::*, mem, fmt};
+
+  const VECTOR_SIZE: usize = mem::size_of::<__m128i>();
 
+  #[inline]
   pub fn fmt_escaped_html(s: &str, f: &mut fmt::Formatter<'_>) -> fmt::Result {
     // the strategy here is to walk s in chunks of VECTOR_SIZE (16) bytes at
     // a time:
@@ -102,9 +89,11 @@ mod simd {
     // when the number of HTML special bytes in the buffer is relatively low,
     // this allows us to quickly go through the buffer without a lookup for
     // every single byte
-    let bytes = s.as_bytes();
-    debug_assert!(bytes.len() >= VECTOR_SIZE);
+    if s.len() < VECTOR_SIZE {
+      return super::fmt_escaped_html_scalar(s, f);
+    }
 
+    let bytes = s.as_bytes();
     let mut mark = 0;
     let mut offset = 0;
 
@@ -120,9 +109,9 @@ mod simd {
 
           // here we know c = s[i] is a character that should be escaped,
           // so it is safe to unwrap ESCAPE_TABLE[c]
-          let replacement = super::ESCAPE_TABLE[c].unwrap();
+          let escape_seq = super::ESCAPE_TABLE[c].unwrap();
           f.write_str(s.get_unchecked(mark..i))?;
-          f.write_str(replacement)?;
+          f.write_str(escape_seq)?;
 
           mark = i + 1; // all escaped characters are ASCII
           mask ^= mask & -mask;
@@ -144,9 +133,9 @@ mod simd {
 
         // here we know c = s[i] is a character that should be escaped,
         // so it is safe to unwrap ESCAPE_TABLE[c]
-        let replacement = super::ESCAPE_TABLE[c].unwrap();
+        let escape_seq = super::ESCAPE_TABLE[c].unwrap();
         f.write_str(s.get_unchecked(mark..i))?;
-        f.write_str(replacement)?;
+        f.write_str(escape_seq)?;
 
         mark = i + 1; // all escaped characters are ASCII
         mask ^= mask & -mask;
@@ -157,6 +146,7 @@ mod simd {
   }
 
 
+  #[inline]
   #[target_feature(enable = "ssse3")]
   /// Computes a byte mask at given offset in the byte buffer. Its first 16
   /// (least significant) bits correspond to whether there is an HTML special