From 703ff60d9f885dfe317e3d271d9f341509efac92 Mon Sep 17 00:00:00 2001 From: Lukas Bergdoll Date: Sat, 21 Jan 2023 10:17:06 +0100 Subject: [PATCH 1/4] Use NonNull in merge_sort This is more clear about the intent of the pointer and avoids problems if the allocation returns a null pointer. --- library/core/src/slice/sort.rs | 34 +++++++++++++++++++--------------- 1 file changed, 19 insertions(+), 15 deletions(-) diff --git a/library/core/src/slice/sort.rs b/library/core/src/slice/sort.rs index 2181f9a811855..7f8895b150fe7 100644 --- a/library/core/src/slice/sort.rs +++ b/library/core/src/slice/sort.rs @@ -1203,7 +1203,7 @@ pub fn merge_sort( // `is_less` panics. When merging two sorted runs, this buffer holds a copy of the shorter run, // which will always have length at most `len / 2`. let buf = BufGuard::new(len / 2, elem_alloc_fn, elem_dealloc_fn); - let buf_ptr = buf.buf_ptr; + let buf_ptr = buf.buf_ptr.as_ptr(); let mut runs = RunVec::new(run_alloc_fn, run_dealloc_fn); @@ -1298,7 +1298,7 @@ pub fn merge_sort( where ElemDeallocF: Fn(*mut T, usize), { - buf_ptr: *mut T, + buf_ptr: ptr::NonNull, capacity: usize, elem_dealloc_fn: ElemDeallocF, } @@ -1315,7 +1315,11 @@ pub fn merge_sort( where ElemAllocF: Fn(usize) -> *mut T, { - Self { buf_ptr: elem_alloc_fn(len), capacity: len, elem_dealloc_fn } + Self { + buf_ptr: ptr::NonNull::new(elem_alloc_fn(len)).unwrap(), + capacity: len, + elem_dealloc_fn, + } } } @@ -1324,7 +1328,7 @@ pub fn merge_sort( ElemDeallocF: Fn(*mut T, usize), { fn drop(&mut self) { - (self.elem_dealloc_fn)(self.buf_ptr, self.capacity); + (self.elem_dealloc_fn)(self.buf_ptr.as_ptr(), self.capacity); } } @@ -1333,7 +1337,7 @@ pub fn merge_sort( RunAllocF: Fn(usize) -> *mut TimSortRun, RunDeallocF: Fn(*mut TimSortRun, usize), { - buf_ptr: *mut TimSortRun, + buf_ptr: ptr::NonNull, capacity: usize, len: usize, run_alloc_fn: RunAllocF, @@ -1350,7 +1354,7 @@ pub fn merge_sort( const START_RUN_CAPACITY: usize = 16; Self { - buf_ptr: 
run_alloc_fn(START_RUN_CAPACITY), + buf_ptr: ptr::NonNull::new(run_alloc_fn(START_RUN_CAPACITY)).unwrap(), capacity: START_RUN_CAPACITY, len: 0, run_alloc_fn, @@ -1361,15 +1365,15 @@ pub fn merge_sort( fn push(&mut self, val: TimSortRun) { if self.len == self.capacity { let old_capacity = self.capacity; - let old_buf_ptr = self.buf_ptr; + let old_buf_ptr = self.buf_ptr.as_ptr(); self.capacity = self.capacity * 2; - self.buf_ptr = (self.run_alloc_fn)(self.capacity); + self.buf_ptr = ptr::NonNull::new((self.run_alloc_fn)(self.capacity)).unwrap(); // SAFETY: buf_ptr new and old were correctly allocated and old_buf_ptr has // old_capacity valid elements. unsafe { - ptr::copy_nonoverlapping(old_buf_ptr, self.buf_ptr, old_capacity); + ptr::copy_nonoverlapping(old_buf_ptr, self.buf_ptr.as_ptr(), old_capacity); } (self.run_dealloc_fn)(old_buf_ptr, old_capacity); @@ -1377,7 +1381,7 @@ pub fn merge_sort( // SAFETY: The invariant was just checked. unsafe { - self.buf_ptr.add(self.len).write(val); + self.buf_ptr.as_ptr().add(self.len).write(val); } self.len += 1; } @@ -1390,7 +1394,7 @@ pub fn merge_sort( // SAFETY: buf_ptr needs to be valid and len invariant upheld. unsafe { // the place we are taking from. - let ptr = self.buf_ptr.add(index); + let ptr = self.buf_ptr.as_ptr().add(index); // Shift everything down to fill in that spot. ptr::copy(ptr.add(1), ptr, self.len - index - 1); @@ -1400,7 +1404,7 @@ pub fn merge_sort( fn as_slice(&self) -> &[TimSortRun] { // SAFETY: Safe as long as buf_ptr is valid and len invariant was upheld. - unsafe { &*ptr::slice_from_raw_parts(self.buf_ptr, self.len) } + unsafe { &*ptr::slice_from_raw_parts(self.buf_ptr.as_ptr(), self.len) } } fn len(&self) -> usize { @@ -1419,7 +1423,7 @@ pub fn merge_sort( if index < self.len { // SAFETY: buf_ptr and len invariant must be upheld. 
unsafe { - return &*(self.buf_ptr.add(index)); + return &*(self.buf_ptr.as_ptr().add(index)); } } @@ -1436,7 +1440,7 @@ pub fn merge_sort( if index < self.len { // SAFETY: buf_ptr and len invariant must be upheld. unsafe { - return &mut *(self.buf_ptr.add(index)); + return &mut *(self.buf_ptr.as_ptr().add(index)); } } @@ -1452,7 +1456,7 @@ pub fn merge_sort( fn drop(&mut self) { // As long as TimSortRun is Copy we don't need to drop them individually but just the // whole allocation. - (self.run_dealloc_fn)(self.buf_ptr, self.capacity); + (self.run_dealloc_fn)(self.buf_ptr.as_ptr(), self.capacity); } } } From a3065a1a34fe1c0b85bdf3ff1f3d0bd470235e6b Mon Sep 17 00:00:00 2001 From: Lukas Bergdoll Date: Sun, 22 Jan 2023 11:55:35 +0100 Subject: [PATCH 2/4] Unify insertion sort implementations Avoid duplicate insertion sort implementations. Optimize implementations. --- library/core/src/slice/sort.rs | 359 +++++++++++++++++---------------- 1 file changed, 188 insertions(+), 171 deletions(-) diff --git a/library/core/src/slice/sort.rs b/library/core/src/slice/sort.rs index 7f8895b150fe7..6bb53b16e6100 100644 --- a/library/core/src/slice/sort.rs +++ b/library/core/src/slice/sort.rs @@ -13,115 +13,178 @@ use crate::cmp; use crate::mem::{self, MaybeUninit, SizedTypeProperties}; use crate::ptr; -/// When dropped, copies from `src` into `dest`. -struct CopyOnDrop { +// When dropped, copies from `src` into `dest`. +struct InsertionHole { src: *const T, dest: *mut T, } -impl Drop for CopyOnDrop { +impl Drop for InsertionHole { fn drop(&mut self) { - // SAFETY: This is a helper class. - // Please refer to its usage for correctness. - // Namely, one must be sure that `src` and `dst` does not overlap as required by `ptr::copy_nonoverlapping`. unsafe { ptr::copy_nonoverlapping(self.src, self.dest, 1); } } } -/// Shifts the first element to the right until it encounters a greater or equal element. 
-fn shift_head(v: &mut [T], is_less: &mut F) +/// Inserts `v[v.len() - 1]` into pre-sorted sequence `v[..v.len() - 1]` so that whole `v[..]` +/// becomes sorted. +unsafe fn insert_tail(v: &mut [T], is_less: &mut F) where F: FnMut(&T, &T) -> bool, { - let len = v.len(); - // SAFETY: The unsafe operations below involves indexing without a bounds check (by offsetting a - // pointer) and copying memory (`ptr::copy_nonoverlapping`). - // - // a. Indexing: - // 1. We checked the size of the array to >=2. - // 2. All the indexing that we will do is always between {0 <= index < len} at most. - // - // b. Memory copying - // 1. We are obtaining pointers to references which are guaranteed to be valid. - // 2. They cannot overlap because we obtain pointers to difference indices of the slice. - // Namely, `i` and `i-1`. - // 3. If the slice is properly aligned, the elements are properly aligned. - // It is the caller's responsibility to make sure the slice is properly aligned. - // - // See comments below for further detail. + debug_assert!(v.len() >= 2); + + let arr_ptr = v.as_mut_ptr(); + let i = v.len() - 1; + + // SAFETY: caller must ensure v is at least len 2. unsafe { - // If the first two elements are out-of-order... - if len >= 2 && is_less(v.get_unchecked(1), v.get_unchecked(0)) { - // Read the first element into a stack-allocated variable. If a following comparison - // operation panics, `hole` will get dropped and automatically write the element back - // into the slice. - let tmp = mem::ManuallyDrop::new(ptr::read(v.get_unchecked(0))); - let v = v.as_mut_ptr(); - let mut hole = CopyOnDrop { src: &*tmp, dest: v.add(1) }; - ptr::copy_nonoverlapping(v.add(1), v.add(0), 1); - - for i in 2..len { - if !is_less(&*v.add(i), &*tmp) { + // See insert_head which talks about why this approach is beneficial. + let i_ptr = arr_ptr.add(i); + + // It's important that we use i_ptr here. 
If this check is positive and we continue, + // We want to make sure that no other copy of the value was seen by is_less. + // Otherwise we would have to copy it back. + if is_less(&*i_ptr, &*i_ptr.sub(1)) { + // It's important, that we use tmp for comparison from now on. As it is the value that + // will be copied back. And notionally we could have created a divergence if we copy + // back the wrong value. + let tmp = mem::ManuallyDrop::new(ptr::read(i_ptr)); + // Intermediate state of the insertion process is always tracked by `hole`, which + // serves two purposes: + // 1. Protects integrity of `v` from panics in `is_less`. + // 2. Fills the remaining hole in `v` in the end. + // + // Panic safety: + // + // If `is_less` panics at any point during the process, `hole` will get dropped and + // fill the hole in `v` with `tmp`, thus ensuring that `v` still holds every object it + // initially held exactly once. + let mut hole = InsertionHole { src: &*tmp, dest: i_ptr.sub(1) }; + ptr::copy_nonoverlapping(hole.dest, i_ptr, 1); + + // SAFETY: We know i is at least 1. + for j in (0..(i - 1)).rev() { + let j_ptr = arr_ptr.add(j); + if !is_less(&*tmp, &*j_ptr) { break; } - // Move `i`-th element one place to the left, thus shifting the hole to the right. - ptr::copy_nonoverlapping(v.add(i), v.add(i - 1), 1); - hole.dest = v.add(i); + ptr::copy_nonoverlapping(j_ptr, hole.dest, 1); + hole.dest = j_ptr; } // `hole` gets dropped and thus copies `tmp` into the remaining hole in `v`. } } } -/// Shifts the last element to the left until it encounters a smaller or equal element. -fn shift_tail(v: &mut [T], is_less: &mut F) +/// Inserts `v[0]` into pre-sorted sequence `v[1..]` so that whole `v[..]` becomes sorted. +/// +/// This is the integral subroutine of insertion sort. 
+unsafe fn insert_head(v: &mut [T], is_less: &mut F) where F: FnMut(&T, &T) -> bool, { - let len = v.len(); - // SAFETY: The unsafe operations below involves indexing without a bound check (by offsetting a - // pointer) and copying memory (`ptr::copy_nonoverlapping`). - // - // a. Indexing: - // 1. We checked the size of the array to >= 2. - // 2. All the indexing that we will do is always between `0 <= index < len-1` at most. - // - // b. Memory copying - // 1. We are obtaining pointers to references which are guaranteed to be valid. - // 2. They cannot overlap because we obtain pointers to difference indices of the slice. - // Namely, `i` and `i+1`. - // 3. If the slice is properly aligned, the elements are properly aligned. - // It is the caller's responsibility to make sure the slice is properly aligned. - // - // See comments below for further detail. + debug_assert!(v.len() >= 2); + unsafe { - // If the last two elements are out-of-order... - if len >= 2 && is_less(v.get_unchecked(len - 1), v.get_unchecked(len - 2)) { - // Read the last element into a stack-allocated variable. If a following comparison - // operation panics, `hole` will get dropped and automatically write the element back - // into the slice. - let tmp = mem::ManuallyDrop::new(ptr::read(v.get_unchecked(len - 1))); - let v = v.as_mut_ptr(); - let mut hole = CopyOnDrop { src: &*tmp, dest: v.add(len - 2) }; - ptr::copy_nonoverlapping(v.add(len - 2), v.add(len - 1), 1); - - for i in (0..len - 2).rev() { - if !is_less(&*tmp, &*v.add(i)) { + if is_less(v.get_unchecked(1), v.get_unchecked(0)) { + let arr_ptr = v.as_mut_ptr(); + + // There are three ways to implement insertion here: + // + // 1. Swap adjacent elements until the first one gets to its final destination. + // However, this way we copy data around more than is necessary. If elements are big + // structures (costly to copy), this method will be slow. + // + // 2. Iterate until the right place for the first element is found. 
Then shift the + // elements succeeding it to make room for it and finally place it into the + // remaining hole. This is a good method. + // + // 3. Copy the first element into a temporary variable. Iterate until the right place + // for it is found. As we go along, copy every traversed element into the slot + // preceding it. Finally, copy data from the temporary variable into the remaining + // hole. This method is very good. Benchmarks demonstrated slightly better + // performance than with the 2nd method. + // + // All methods were benchmarked, and the 3rd showed best results. So we chose that one. + let tmp = mem::ManuallyDrop::new(ptr::read(arr_ptr)); + + // Intermediate state of the insertion process is always tracked by `hole`, which + // serves two purposes: + // 1. Protects integrity of `v` from panics in `is_less`. + // 2. Fills the remaining hole in `v` in the end. + // + // Panic safety: + // + // If `is_less` panics at any point during the process, `hole` will get dropped and + // fill the hole in `v` with `tmp`, thus ensuring that `v` still holds every object it + // initially held exactly once. + let mut hole = InsertionHole { src: &*tmp, dest: arr_ptr.add(1) }; + ptr::copy_nonoverlapping(arr_ptr.add(1), arr_ptr.add(0), 1); + + for i in 2..v.len() { + if !is_less(&v.get_unchecked(i), &*tmp) { break; } - - // Move `i`-th element one place to the right, thus shifting the hole to the left. - ptr::copy_nonoverlapping(v.add(i), v.add(i + 1), 1); - hole.dest = v.add(i); + ptr::copy_nonoverlapping(arr_ptr.add(i), arr_ptr.add(i - 1), 1); + hole.dest = arr_ptr.add(i); } // `hole` gets dropped and thus copies `tmp` into the remaining hole in `v`. } } } +/// Sort `v` assuming `v[..offset]` is already sorted. +/// +/// Never inline this function to avoid code bloat. It still optimizes nicely and has practically no +/// performance impact. Even improving performance in some cases. 
+#[inline(never)] +fn insertion_sort_shift_left(v: &mut [T], offset: usize, is_less: &mut F) +where + F: FnMut(&T, &T) -> bool, +{ + let len = v.len(); + + // Using assert here improves performance. + assert!(offset != 0 && offset <= len); + + // Shift each element of the unsorted region v[i..] as far left as is needed to make v sorted. + for i in offset..len { + // SAFETY: we tested that `offset` must be at least 1, so this loop is only entered if len + // >= 2. + unsafe { + insert_tail(&mut v[..=i], is_less); + } + } +} + +/// Sort `v` assuming `v[offset..]` is already sorted. +/// +/// Never inline this function to avoid code bloat. It still optimizes nicely and has practically no +/// performance impact. Even improving performance in some cases. +#[inline(never)] +fn insertion_sort_shift_right(v: &mut [T], offset: usize, is_less: &mut F) +where + F: FnMut(&T, &T) -> bool, +{ + let len = v.len(); + + // Using assert here improves performance. + assert!(offset != 0 && offset <= len && len >= 2); + + // Shift each element of the unsorted region v[..i] as far left as is needed to make v sorted. + for i in (0..offset).rev() { + // We ensured that the slice length is always at least 2 long. + // We know that start_found will be at least one less than end, + // and the range is exclusive. Which gives us i always <= (end - 2). + unsafe { + insert_head(&mut v[i..len], is_less); + } + } +} + /// Partially sorts a slice by shifting several out-of-order elements around. /// /// Returns `true` if the slice is sorted at the end. This function is *O*(*n*) worst-case. @@ -161,26 +224,19 @@ where // Swap the found pair of elements. This puts them in correct order. v.swap(i - 1, i); - // Shift the smaller element to the left. - shift_tail(&mut v[..i], is_less); - // Shift the greater element to the right. - shift_head(&mut v[i..], is_less); + if i >= 2 { + // Shift the smaller element to the left. 
+ insertion_sort_shift_left(&mut v[..i], i - 1, is_less); + + // Shift the greater element to the right. + insertion_sort_shift_right(&mut v[..i], 1, is_less); + } } // Didn't manage to sort the slice in the limited number of steps. false } -/// Sorts a slice using insertion sort, which is *O*(*n*^2) worst-case. -fn insertion_sort(v: &mut [T], is_less: &mut F) -where - F: FnMut(&T, &T) -> bool, -{ - for i in 1..v.len() { - shift_tail(&mut v[..i + 1], is_less); - } -} - /// Sorts `v` using heapsort, which guarantees *O*(*n* \* log(*n*)) worst-case. #[cold] #[unstable(feature = "sort_internals", reason = "internal to sort module", issue = "none")] @@ -507,7 +563,7 @@ where // SAFETY: `pivot` is a reference to the first element of `v`, so `ptr::read` is safe. let tmp = mem::ManuallyDrop::new(unsafe { ptr::read(pivot) }); - let _pivot_guard = CopyOnDrop { src: &*tmp, dest: pivot }; + let _pivot_guard = InsertionHole { src: &*tmp, dest: pivot }; let pivot = &*tmp; // Find the first pair of out-of-order elements. @@ -560,7 +616,7 @@ where // operation panics, the pivot will be automatically written back into the slice. // SAFETY: The pointer here is valid because it is obtained from a reference to a slice. let tmp = mem::ManuallyDrop::new(unsafe { ptr::read(pivot) }); - let _pivot_guard = CopyOnDrop { src: &*tmp, dest: pivot }; + let _pivot_guard = InsertionHole { src: &*tmp, dest: pivot }; let pivot = &*tmp; // Now partition the slice. @@ -742,7 +798,9 @@ where // Very short slices get sorted using insertion sort. if len <= MAX_INSERTION { - insertion_sort(v, is_less); + if len >= 2 { + insertion_sort_shift_left(v, 1, is_less); + } return; } @@ -844,10 +902,14 @@ fn partition_at_index_loop<'a, T, F>( let mut was_balanced = true; loop { + let len = v.len(); + // For slices of up to this length it's probably faster to simply sort them. 
const MAX_INSERTION: usize = 10; - if v.len() <= MAX_INSERTION { - insertion_sort(v, is_less); + if len <= MAX_INSERTION { + if len >= 2 { + insertion_sort_shift_left(v, 1, is_less); + } return; } @@ -887,7 +949,7 @@ fn partition_at_index_loop<'a, T, F>( } let (mid, _) = partition(v, pivot, is_less); - was_balanced = cmp::min(mid, v.len() - mid) >= v.len() / 8; + was_balanced = cmp::min(mid, len - mid) >= len / 8; // Split the slice into `left`, `pivot`, and `right`. let (left, right) = v.split_at_mut(mid); @@ -954,75 +1016,6 @@ where (left, pivot, right) } -/// Inserts `v[0]` into pre-sorted sequence `v[1..]` so that whole `v[..]` becomes sorted. -/// -/// This is the integral subroutine of insertion sort. -fn insert_head(v: &mut [T], is_less: &mut F) -where - F: FnMut(&T, &T) -> bool, -{ - if v.len() >= 2 && is_less(&v[1], &v[0]) { - // SAFETY: Copy tmp back even if panic, and ensure unique observation. - unsafe { - // There are three ways to implement insertion here: - // - // 1. Swap adjacent elements until the first one gets to its final destination. - // However, this way we copy data around more than is necessary. If elements are big - // structures (costly to copy), this method will be slow. - // - // 2. Iterate until the right place for the first element is found. Then shift the - // elements succeeding it to make room for it and finally place it into the - // remaining hole. This is a good method. - // - // 3. Copy the first element into a temporary variable. Iterate until the right place - // for it is found. As we go along, copy every traversed element into the slot - // preceding it. Finally, copy data from the temporary variable into the remaining - // hole. This method is very good. Benchmarks demonstrated slightly better - // performance than with the 2nd method. - // - // All methods were benchmarked, and the 3rd showed best results. So we chose that one. 
- let tmp = mem::ManuallyDrop::new(ptr::read(&v[0])); - - // Intermediate state of the insertion process is always tracked by `hole`, which - // serves two purposes: - // 1. Protects integrity of `v` from panics in `is_less`. - // 2. Fills the remaining hole in `v` in the end. - // - // Panic safety: - // - // If `is_less` panics at any point during the process, `hole` will get dropped and - // fill the hole in `v` with `tmp`, thus ensuring that `v` still holds every object it - // initially held exactly once. - let mut hole = InsertionHole { src: &*tmp, dest: &mut v[1] }; - ptr::copy_nonoverlapping(&v[1], &mut v[0], 1); - - for i in 2..v.len() { - if !is_less(&v[i], &*tmp) { - break; - } - ptr::copy_nonoverlapping(&v[i], &mut v[i - 1], 1); - hole.dest = &mut v[i]; - } - // `hole` gets dropped and thus copies `tmp` into the remaining hole in `v`. - } - } - - // When dropped, copies from `src` into `dest`. - struct InsertionHole { - src: *const T, - dest: *mut T, - } - - impl Drop for InsertionHole { - fn drop(&mut self) { - // SAFETY: The caller must ensure that src and dest are correctly set. - unsafe { - ptr::copy_nonoverlapping(self.src, self.dest, 1); - } - } - } -} - /// Merges non-decreasing runs `v[..mid]` and `v[mid..]` using `buf` as temporary storage, and /// stores the result into `v[..]`. /// @@ -1180,8 +1173,6 @@ pub fn merge_sort( { // Slices of up to this length get sorted using insertion sort. const MAX_INSERTION: usize = 20; - // Very short runs are extended using insertion sort to span at least this many elements. - const MIN_RUN: usize = 10; // The caller should have already checked that. debug_assert!(!T::IS_ZST); @@ -1191,9 +1182,7 @@ pub fn merge_sort( // Short arrays get sorted in-place via insertion sort to avoid allocations. 
if len <= MAX_INSERTION { if len >= 2 { - for i in (0..len - 1).rev() { - insert_head(&mut v[i..], is_less); - } + insertion_sort_shift_left(v, 1, is_less); } return; } @@ -1236,10 +1225,7 @@ pub fn merge_sort( // Insert some more elements into the run if it's too short. Insertion sort is faster than // merge sort on short sequences, so this significantly improves performance. - while start > 0 && end - start < MIN_RUN { - start -= 1; - insert_head(&mut v[start..end], is_less); - } + start = provide_sorted_batch(v, start, end, is_less); // Push this run onto the stack. runs.push(TimSortRun { start, len: end - start }); @@ -1467,3 +1453,34 @@ pub struct TimSortRun { len: usize, start: usize, } + +/// Takes a range as denoted by start and end, that is already sorted and extends it to the left if +/// necessary with sorts optimized for smaller ranges such as insertion sort. +#[cfg(not(no_global_oom_handling))] +fn provide_sorted_batch(v: &mut [T], mut start: usize, end: usize, is_less: &mut F) -> usize +where + F: FnMut(&T, &T) -> bool, +{ + debug_assert!(end > start); + + // This value is a balance between least comparisons and best performance, as + // influenced by for example cache locality. + const MIN_INSERTION_RUN: usize = 10; + + // Insert some more elements into the run if it's too short. Insertion sort is faster than + // merge sort on short sequences, so this significantly improves performance. + let start_found = start; + let start_end_diff = end - start; + + if start_end_diff < MIN_INSERTION_RUN && start != 0 { + // v[start_found..end] are elements that are already sorted in the input. We want to extend + // the sorted region to the left, so we push up MIN_INSERTION_RUN - 1 to the right. Which is + // more efficient that trying to push those already sorted elements to the left. 
+ + start = if end >= MIN_INSERTION_RUN { end - MIN_INSERTION_RUN } else { 0 }; + + insertion_sort_shift_right(&mut v[start..end], start_found - start, is_less); + } + + start +} From f297afa0c91243b17283be17864f2c48f91127d9 Mon Sep 17 00:00:00 2001 From: Lukas Bergdoll Date: Sun, 22 Jan 2023 12:01:06 +0100 Subject: [PATCH 3/4] Flip scanning direction of stable sort Memory pre-fetching prefers forward scanning vs backwards scanning, and the code-gen is usually better. For the most sensitive types such as integers, these are planned to be merged bidirectionally at once. So there is no benefit in scanning backwards. The largest perf gains are seen for full ascending and descending inputs, which see 1.5x speedups. Random inputs benefit too, and some patterns can lose out, but these losses are minimal. --- library/core/src/slice/sort.rs | 112 ++++++++++++++++++++------------- 1 file changed, 67 insertions(+), 45 deletions(-) diff --git a/library/core/src/slice/sort.rs b/library/core/src/slice/sort.rs index 6bb53b16e6100..227db51a0b403 100644 --- a/library/core/src/slice/sort.rs +++ b/library/core/src/slice/sort.rs @@ -1196,52 +1196,37 @@ pub fn merge_sort( let mut runs = RunVec::new(run_alloc_fn, run_dealloc_fn); - // In order to identify natural runs in `v`, we traverse it backwards. That might seem like a - // strange decision, but consider the fact that merges more often go in the opposite direction - // (forwards). According to benchmarks, merging forwards is slightly faster than merging - // backwards. To conclude, identifying runs by traversing backwards improves performance. - let mut end = len; - while end > 0 { - // Find the next natural run, and reverse it if it's strictly descending. - let mut start = end - 1; - if start > 0 { - start -= 1; - - // SAFETY: The v.get_unchecked must be fed with correct inbound indicies.
- unsafe { - if is_less(v.get_unchecked(start + 1), v.get_unchecked(start)) { - while start > 0 && is_less(v.get_unchecked(start), v.get_unchecked(start - 1)) { - start -= 1; - } - v[start..end].reverse(); - } else { - while start > 0 && !is_less(v.get_unchecked(start), v.get_unchecked(start - 1)) - { - start -= 1; - } - } - } + let mut end = 0; + let mut start = 0; + + // Scan forward. Memory pre-fetching prefers forward scanning vs backwards scanning, and the + // code-gen is usually better. For the most sensitive types such as integers, these are merged + // bidirectionally at once. So there is no benefit in scanning backwards. + while end < len { + let (streak_end, was_reversed) = find_streak(&v[start..], is_less); + end += streak_end; + if was_reversed { + v[start..end].reverse(); } // Insert some more elements into the run if it's too short. Insertion sort is faster than // merge sort on short sequences, so this significantly improves performance. - start = provide_sorted_batch(v, start, end, is_less); + end = provide_sorted_batch(v, start, end, is_less); // Push this run onto the stack. runs.push(TimSortRun { start, len: end - start }); - end = start; + start = end; // Merge some pairs of adjacent runs to satisfy the invariants. - while let Some(r) = collapse(runs.as_slice()) { - let left = runs[r + 1]; - let right = runs[r]; - // SAFETY: `buf_ptr` must hold enough capacity for the shorter of the two sides, and - // neither side may be on length 0. 
+ while let Some(r) = collapse(runs.as_slice(), len) { + let left = runs[r]; + let right = runs[r + 1]; + let merge_slice = &mut v[left.start..right.start + right.len]; unsafe { - merge(&mut v[left.start..right.start + right.len], left.len, buf_ptr, is_less); + merge(merge_slice, left.len, buf_ptr, is_less); } - runs[r] = TimSortRun { start: left.start, len: left.len + right.len }; - runs.remove(r + 1); + runs[r + 1] = TimSortRun { start: left.start, len: left.len + right.len }; + runs.remove(r); } } @@ -1263,10 +1248,10 @@ pub fn merge_sort( // run starts at index 0, it will always demand a merge operation until the stack is fully // collapsed, in order to complete the sort. #[inline] - fn collapse(runs: &[TimSortRun]) -> Option { + fn collapse(runs: &[TimSortRun], stop: usize) -> Option { let n = runs.len(); if n >= 2 - && (runs[n - 1].start == 0 + && (runs[n - 1].start + runs[n - 1].len == stop || runs[n - 2].len <= runs[n - 1].len || (n >= 3 && runs[n - 3].len <= runs[n - 2].len + runs[n - 1].len) || (n >= 4 && runs[n - 4].len <= runs[n - 3].len + runs[n - 2].len)) @@ -1454,14 +1439,15 @@ pub struct TimSortRun { start: usize, } -/// Takes a range as denoted by start and end, that is already sorted and extends it to the left if +/// Takes a range as denoted by start and end, that is already sorted and extends it to the right if /// necessary with sorts optimized for smaller ranges such as insertion sort. #[cfg(not(no_global_oom_handling))] -fn provide_sorted_batch(v: &mut [T], mut start: usize, end: usize, is_less: &mut F) -> usize +fn provide_sorted_batch(v: &mut [T], start: usize, mut end: usize, is_less: &mut F) -> usize where F: FnMut(&T, &T) -> bool, { - debug_assert!(end > start); + let len = v.len(); + assert!(end >= start && end <= len); // This value is a balance between least comparisons and best performance, as // influenced by for example cache locality. @@ -1469,18 +1455,54 @@ where // Insert some more elements into the run if it's too short. 
Insertion sort is faster than // merge sort on short sequences, so this significantly improves performance. - let start_found = start; let start_end_diff = end - start; - if start_end_diff < MIN_INSERTION_RUN && start != 0 { + if start_end_diff < MIN_INSERTION_RUN && end < len { // v[start_found..end] are elements that are already sorted in the input. We want to extend // the sorted region to the left, so we push up MIN_INSERTION_RUN - 1 to the right. Which is // more efficient that trying to push those already sorted elements to the left. + end = cmp::min(start + MIN_INSERTION_RUN, len); + let presorted_start = cmp::max(start_end_diff, 1); - start = if end >= MIN_INSERTION_RUN { end - MIN_INSERTION_RUN } else { 0 }; + insertion_sort_shift_left(&mut v[start..end], presorted_start, is_less); + } - insertion_sort_shift_right(&mut v[start..end], start_found - start, is_less); + end +} + +/// Finds a streak of presorted elements starting at the beginning of the slice. Returns the first +/// value that is not part of said streak, and a bool denoting whether the streak was reversed. +/// Streaks can be increasing or decreasing. +fn find_streak(v: &[T], is_less: &mut F) -> (usize, bool) +where + F: FnMut(&T, &T) -> bool, +{ + let len = v.len(); + + if len < 2 { + return (len, false); } - start + let mut end = 2; + + // SAFETY: See below specific. + unsafe { + // SAFETY: We checked that len >= 2, so 0 and 1 are valid indices. + let assume_reverse = is_less(v.get_unchecked(1), v.get_unchecked(0)); + + // SAFETY: We know end >= 2 and check end < len. + // From that follows that accessing v at end and end - 1 is safe.
+ if assume_reverse { + while end < len && is_less(v.get_unchecked(end), v.get_unchecked(end - 1)) { + end += 1; + } + + (end, true) + } else { + while end < len && !is_less(v.get_unchecked(end), v.get_unchecked(end - 1)) { + end += 1; + } + (end, false) + } + } } From 5eff2645335e86f714a92a592f81936fead1f6a4 Mon Sep 17 00:00:00 2001 From: Lukas Bergdoll Date: Mon, 23 Jan 2023 09:12:25 +0100 Subject: [PATCH 4/4] Document missing unsafe blocks --- library/core/src/slice/sort.rs | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/library/core/src/slice/sort.rs b/library/core/src/slice/sort.rs index 227db51a0b403..fc35c46d58300 100644 --- a/library/core/src/slice/sort.rs +++ b/library/core/src/slice/sort.rs @@ -21,6 +21,9 @@ struct InsertionHole { impl Drop for InsertionHole { fn drop(&mut self) { + // SAFETY: This is a helper class. Please refer to its usage for correctness. Namely, one + // must be sure that `src` and `dest` do not overlap as required by + // `ptr::copy_nonoverlapping`, that `src` is valid for reads and `dest` for writes. unsafe { ptr::copy_nonoverlapping(self.src, self.dest, 1); } @@ -88,6 +91,7 @@ where { debug_assert!(v.len() >= 2); + // SAFETY: caller must ensure v is at least len 2. unsafe { if is_less(v.get_unchecked(1), v.get_unchecked(0)) { let arr_ptr = v.as_mut_ptr(); @@ -153,7 +157,8 @@ where // Shift each element of the unsorted region v[i..] as far left as is needed to make v sorted. for i in offset..len { // SAFETY: we tested that `offset` must be at least 1, so this loop is only entered if len - // >= 2. + // >= 2. The range is exclusive and we know `i` must be at least 1 so this slice has at + // least len 2. unsafe { insert_tail(&mut v[..=i], is_less); } @@ -176,9 +181,10 @@ where // Shift each element of the unsorted region v[..i] as far left as is needed to make v sorted. for i in (0..offset).rev() { - // We ensured that the slice length is always at least 2 long.
- // We know that start_found will be at least one less than end, - // and the range is exclusive. Which gives us i always <= (end - 2). + // SAFETY: we tested that `offset` must be at least 1, so this loop is only entered if len + // >= 2. We ensured that the slice length is always at least 2. We know that start_found + // will be at least one less than end, and the range is exclusive. Which gives us i always + // <= (end - 2). unsafe { insert_head(&mut v[i..len], is_less); } @@ -1222,6 +1228,8 @@ pub fn merge_sort( let left = runs[r]; let right = runs[r + 1]; let merge_slice = &mut v[left.start..right.start + right.len]; + // SAFETY: `buf_ptr` must hold enough capacity for the shorter of the two sides, and + // neither side may be of length 0. unsafe { merge(merge_slice, left.len, buf_ptr, is_less); }