diff --git a/rust/src/main.rs b/rust/src/main.rs index 9372a20..a0817d6 100644 --- a/rust/src/main.rs +++ b/rust/src/main.rs @@ -3,23 +3,19 @@ use rayon::join; fn choose_pivot(slice: &[T]) -> usize { - // if slice.len() <= 2 {return slice.len() - 1;}; - let (mut ismall, imid, mut ibig) = (0, slice.len() / 2, slice.len() - 1); - if slice[ibig] < slice[ismall] { - std::mem::swap(&mut ibig, &mut ismall); + let (mut left, mid, mut right) = (0, slice.len() / 2, slice.len() - 1); + if slice[right] < slice[left] { + std::mem::swap(&mut right, &mut left); } - if slice[imid] <= slice[ismall] { - ismall - } else if slice[ibig] <= slice[imid] { - ibig + if slice[mid] <= slice[left] { + return left; + } else if slice[right] <= slice[mid] { + return right; } else { - imid + return mid; } } -/// choose a pivot, then reorder so that everything to the left of the pivot is smaller, and -/// everything to the right is greater -/// Assumes slice.len() > 2 fn partition(slice: &mut [T], pivot: usize) -> usize { let mxix = slice.len() - 1; slice.swap(pivot, mxix); @@ -52,10 +48,7 @@ fn partition(slice: &mut [T], pivot: usize) -> usize { return left + 1; } - panic!( - "This should be unreachable. Indices: {}, {} / {}", - left, right, mxix - ); + panic!("partition failed.") } fn quicksort(slice: &mut [T]) { @@ -68,11 +61,10 @@ fn quicksort(slice: &mut [T]) { return; } - let pivot = choose_pivot(slice); - let pivot = partition(slice, pivot); + let pivot = partition(slice, choose_pivot(slice)); let (left_slice, right_slice) = slice.split_at_mut(pivot); - let right_slice = &mut right_slice[1..]; + let right_slice = &mut right_slice[1..]; // want to exclude pivot quicksort(left_slice); quicksort(right_slice); @@ -88,12 +80,9 @@ fn par_quicksort(slice: &mut [T]) { return; } - let pivot = choose_pivot(slice); - let pivot = partition(slice, pivot); + let pivot = partition(slice, choose_pivot(slice)); let (left_slice, right_slice) = slice.split_at_mut(pivot); - // left_slice is [0 - pivot-1], right_slice is [pivot, end]. We don't want to include the - // pivot, so reassign right_slice - let right_slice = &mut right_slice[1..]; + let right_slice = &mut right_slice[1..]; // want to exclude pivot join(|| quicksort(left_slice), || quicksort(right_slice)); } @@ -132,3 +121,17 @@ fn qs(b: &mut Bencher) { b.iter(|| quicksort(&mut test_vec)) } + +#[test] +fn qs_comp() { + let bv = get_bench_vec(); + let mut par = bv.clone(); + let mut basic = bv.clone(); + + par_quicksort(&mut par); + quicksort(&mut basic); + + for (l, r) in par.iter().zip(basic.iter()) { + assert_eq!(l, r) + } +}