
Quick(er)sort, Parallelism, and Fork/Join in JDK7

2011-05-27

Most men age twenty-six can boast few accomplishments. Sir Charles Antony Richard Hoare is not most men. At that age, C.A.R. Hoare invented Quicksort (pdf).

Quicksort is typically more efficient than other comparison sorting algorithms in both auxiliary memory requirements and execution cycles. Yet it shares its peers’ average-case time complexity bound: O(n*log n).

Parallelism is the tool to effectively breach that barrier. As a binary tree sort, Quicksort is especially suited for parallelism since multiple threads of execution can work on sorting task parts without synchronized data access.

We will detail two effective synchronization policies for parallel Quicksort in Java. One is usable now, one is coming soon in JDK 7.

  1. shared count down latch
  2. Fork/Join framework, new in JDK 7 (JSR 166)

The code here is condensed for rapid comprehension. To see more proper implementations of the concepts, visit https://github.com/pmbauer/parallel.

For a quick refresher on the serial version of Quicksort, the Algolist Quicksort article is excellent; its explanation and visuals are clearer than C.A.R. Hoare’s Oxford journal paper. All parallel implementations use in-place partitioning.
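As a point of reference, here is a minimal sketch of serial Quicksort with in-place, Hoare-style partitioning. It is an assumption about the shape of partition, not a copy of the project's code: it returns an index p such that everything in [left, p] is less than or equal to everything in [p + 1, right], which matches how the listings in this post recurse on (left, pivot) and (pivot + 1, right). The class name SerialQuicksort and the middle-element pivot choice are made up for illustration.

```java
class SerialQuicksort {
    static void sort(int[] a) {
        sort(a, 0, a.length - 1);
    }

    private static void sort(int[] a, int left, int right) {
        if (left < right) {
            int p = partition(a, left, right);
            sort(a, left, p);       // left partition includes index p
            sort(a, p + 1, right);  // right partition starts after p
        }
    }

    // Hoare-style partition around the (lower) middle element.
    // For left < right the returned index is always in [left, right - 1],
    // so both recursive calls work on strictly smaller ranges.
    static int partition(int[] a, int left, int right) {
        int pivot = a[left + (right - left) / 2];
        int i = left - 1;
        int j = right + 1;
        while (true) {
            do { i++; } while (a[i] < pivot);
            do { j--; } while (a[j] > pivot);
            if (i >= j) {
                return j;
            }
            int tmp = a[i];
            a[i] = a[j];
            a[j] = tmp;
        }
    }
}
```

Calling SerialQuicksort.sort(a) sorts the array in place; no auxiliary array is allocated, only the recursion stack.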

naive first take

First, a naive implementation to motivate what follows.

When writing a parallel algorithm, identify your synchronization policy early on. This answers the question, “How does the caller know when we’re done?”

The simplest synchronization policy we have is the JVM call stack - straight serial programming. To model that with threads, recursively spawn a new thread to work on every new partition made. Then, wait for that thread to finish. Here’s how that might look:

class NaiveParallelQuickSort implements Runnable {
    private int[] a;
    private int left;
    private int right;

    public static void sort(int[] a) {
        Thread root = new Thread(new NaiveParallelQuickSort(a, 0, a.length - 1));
        root.start();
        try {
            root.join();
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt(); // preserve the interrupt for the caller
        }
    }

    public NaiveParallelQuickSort(int[] a, int left, int right) {
        this.a = a;
        this.left = left;
        this.right = right;
    }

    public void run() {
        int pivot = 0;

        if (left < right) {
            pivot = partition(a, left, right);

            Thread t1 = new Thread(new NaiveParallelQuickSort(a, left, pivot));
            Thread t2 = new Thread(new NaiveParallelQuickSort(a, pivot + 1, right));

            t1.start(); t2.start();

            try { // wait for the results
                t1.join(); t2.join();
            } catch (InterruptedException e) {
                return;
            }
        }
    }

    private static int partition(int[] a, int i, int j) {
        // see implementation in http://goo.gl/T46KI (github)
    }
}

Break out the Fail Whale. Spawning threads is expensive - orders of magnitude more expensive than function calls - and we’re spawning 2*n threads here! Even with an infinite number of cores, the naive approach is a great way to model glacial movement. We need a way to manage and limit all these threads running around.

executor service

Thanks to Doug Lea, Java features robust concurrency primitives: blocking queues, thread-safe collections, assorted locks, atomic values, and thread pools. Thread pools implement the ExecutorService interface. Just like Thread, an ExecutorService executes Runnables, so only a few minor modifications to the naive attempt are needed:

class ParallelQuickSortTake2 implements Runnable {
    private static ExecutorService pool =
        Executors.newFixedThreadPool(Runtime.getRuntime().availableProcessors());
    ...
    public void run() {
        int pivot = 0;

        if (left < right) {
            pivot = partition(a, left, right);

            pool.execute(new ParallelQuickSortTake2(a, left, pivot));
            pool.execute(new ParallelQuickSortTake2(a, pivot + 1, right));
        }
    }
}

Our second take is deceptively simple, yet it has shortcomings that render it unusable.

  1. The overhead for execute is still much greater than that of recursive method calls. Internally, execute pushes the Runnable onto a shared blocking queue. Threads then pull from the queue, creating a single point of contention.
  2. Creating a new ParallelQuickSortTake2 object for every partition causes ~6*n auxiliary memory overhead (2 * n * (2 ints + 1 array reference)). That much object creation will thrash the garbage collector in a hurry.
  3. The execute call is asynchronous, so we have no way to let the calling process know when the work is done.

1 and 2 are problems of communication overhead.

granularity

Communication overhead makes parallelism have diminishing - usually negative - returns below a certain task size. The simple fix is to serially execute a task if it is smaller than a certain threshold.

But, finding the optimal granularity is non-trivial since it can fluctuate greatly between machines and inputs. One approach is to instrument our algorithm with live profiling and a feedback loop. That’s another article.

As a simplifying alternative, we can set a hard threshold. Via experimentation, I found 0x1000 (4096) is roughly optimal on my quad-core machine.

class ParallelQuickSortTake3 implements Runnable {
    private static final int SERIAL_THRESHOLD = 0x1000;
    ...
    public void run() {
        int pivot = 0;

        if (right - left < SERIAL_THRESHOLD)
            Arrays.sort(a, left, right + 1);
        else {
            pivot = partition(a, left, right);

            pool.execute(new ParallelQuickSortTake3(a, left, pivot));
            pool.execute(new ParallelQuickSortTake3(a, pivot + 1, right));
        }
    }
}

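The key win here: partitions below the threshold are sorted serially with Arrays.sort, so far fewer tasks are queued and far fewer task objects are created, which addresses problems 1 and 2.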

That leaves the problem of notifying the invoking process when the sort is complete.

count-down latch (approach 1)

A count-down latch is a synchronization primitive that allows one or more threads to wait on a set of tasks being performed in other threads. As each bit of work is completed, workers decrement the count. Once the count hits zero, all waiting threads are signaled that the work is done.
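To make that concrete, here is a small sketch using the standard java.util.concurrent.CountDownLatch. The class and method names (LatchDemo, runAndWait) and the pool size of 4 are invented for illustration; the tasks are trivial stand-ins for real work.

```java
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.atomic.AtomicInteger;

class LatchDemo {
    // Runs `tasks` trivial jobs on a small pool and returns how many had
    // finished by the time the latch released the calling thread.
    static int runAndWait(int tasks) throws InterruptedException {
        final ExecutorService pool = Executors.newFixedThreadPool(4);
        final CountDownLatch latch = new CountDownLatch(tasks);
        final AtomicInteger finished = new AtomicInteger();
        try {
            for (int i = 0; i < tasks; i++) {
                pool.execute(new Runnable() {
                    public void run() {
                        finished.incrementAndGet(); // the "work"
                        latch.countDown();          // one unit of work done
                    }
                });
            }
            latch.await(); // blocks until the count reaches zero
            return finished.get();
        } finally {
            pool.shutdown();
        }
    }
}
```

Because every task counts down only after doing its work, the caller is released only once all of them have finished.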

For our case, the size of the work and the initial value of our count-down latch are both equal to the size of the array being sorted. As each array value is placed in its sorted position, the count-down latch is decremented accordingly. When all values are in their sorted position, the count-down latch hits zero and the invoking process is signaled.

Lucky for us, JDK 5 introduced a CountDownLatch. Unlucky for us, it only exposes a single decrementing function, countDown(), that decrements the count by one. Our algorithm needs to efficiently decrement the count in increments of SERIAL_THRESHOLD or less.

Fortunately, such a count-down latch is easy to write and I provide one here: CountDownLatch.java
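For readers who want the idea without following the link, here is one possible sketch of a latch with a bulk decrement. It is not the linked CountDownLatch.java: the name BulkCountDownLatch and its design (an AtomicInteger for the count plus a one-shot standard latch for blocking) are assumptions made for illustration.

```java
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.atomic.AtomicInteger;

final class BulkCountDownLatch {
    private final AtomicInteger remaining;
    // One-shot gate, opened when `remaining` reaches zero.
    private final CountDownLatch done = new CountDownLatch(1);

    BulkCountDownLatch(int count) {
        if (count < 0) {
            throw new IllegalArgumentException("count < 0");
        }
        remaining = new AtomicInteger(count);
        if (count == 0) {
            done.countDown();
        }
    }

    // Decrements the count by `delta` in a single atomic step.
    // Non-positive deltas are ignored.
    void countDown(int delta) {
        if (delta <= 0) {
            return;
        }
        if (remaining.addAndGet(-delta) <= 0) {
            done.countDown(); // extra calls on an open gate are harmless
        }
    }

    void await() throws InterruptedException {
        done.await();
    }

    int getCount() {
        return Math.max(0, remaining.get());
    }
}
```

A worker that has just sorted a block of k values calls countDown(k) once instead of calling countDown() k times.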

Armed with our custom latch, we can complete our first approach. Here is a (highly) condensed listing (original here: LatchQuicksortTask.java):

 1public class LatchQuicksortTask implements Runnable {
 2    private static final int SERIAL_THRESHOLD = 0x1000;
 3
 4    // Defines bounded region (sub-array) of array to sort.
 5    private static class QuicksortSubTask implements Runnable {
 6        private final int left;
 7        private final int right;
 8        private final LatchQuicksortTask root;
 9
10        QuicksortSubTask(LatchQuicksortTask task) {...}
11        QuicksortSubTask(LatchQuicksortTask rootTask, int left, int right) {...}
12
13        public void run() {
14            int pivotIndex = root.partitionOrSort(left, right);
15
16            if (pivotIndex >= 0) {
17                if (left < pivotIndex)
18                    root.pool.execute(new QuicksortSubTask(root, left, pivotIndex));
19                if (pivotIndex + 1 < right)
20                    root.pool.execute(new QuicksortSubTask(root, pivotIndex + 1, right));
21            }
22        }
23
24    }
25
26    private final ExecutorService pool;
27    private final CountDownLatch latch;
28    private final int[] a;
29
30    public LatchQuicksortTask(int[] a, ExecutorService threadPool) {
31        pool = threadPool;
32        this.a = a;
33        latch = new CountDownLatch(a.length);
34    }
35
36    public final void waitUntilSorted() throws InterruptedException { latch.await(); }
37
38    public void run() { pool.execute(new QuicksortSubTask(this)); }
39
40    // Dependent on size of subsection, partitions and returns pivot,
41    //   or sorts and returns -1
42    private int partitionOrSort(int left, int right) {
43        int pivotIndex = -1;
44        int sortedCount;
45
46        if (serialThresholdMet(left, right)) {
47            Arrays.sort(a, left, right + 1);
48            sortedCount = right - left + 1;
49        } else {
50            pivotIndex = partition(a, left, right);
51            sortedCount = countSortedBoundaryValues(left, right, pivotIndex);
52        }
53
54        latch.countDown(sortedCount);
55
56        return pivotIndex;
57    }
58
59    /*
60     * When left == pivotIndex, then a[left] is in its sorted position.
61     * When right == pivotIndex + 1, then a[right] is in its sorted position.
62     *
63     * As long as SERIAL_THRESHOLD is guaranteed > 2, these two conditions are
64     * mutually exclusive.
65     * Therefore, we can gain a minor efficiency and avoid an extra branch for
66     * each case.
67     */
68    private int countSortedBoundaryValues(int left, int right, int pivotIndex) {
69        return (left == pivotIndex || right == pivotIndex + 1) ? 1 : 0;
70    }
71
72    private boolean serialThresholdMet(int left, int right) {
73        return right - left < SERIAL_THRESHOLD;
74    }
75}

[5,8,38] In addition to the array reference, each task needs references to the count-down latch and the thread pool. So, some refactoring was in order to encapsulate that trio. LatchQuicksortTask holds a reference to the array, latch, and pool. QuicksortSubTask does the actual work and holds a single reference to its root; the first sub-task is spawned in the root’s run() method.

[33,36] The root task encapsulates the count-down latch, and initializes the count to the size of the array to sort. The invoking process can block until the sort is done by calling waitUntilSorted().

[42] partitionOrSort(int,int) is called from the sub-task, thus from inside the pool.

[48] If we are at the threshold, we will sort a whole section at once. We calculate the size of the area sorted so we can decrement our latch later.

[51] In this case, we partition the sub-section per normal Quicksort. The nature of partition is such that if a poor pivot is selected, one partition might be very small. In the extreme case where the pivot is equal to a boundary, that value is already sorted and will not be included in the spawned sub-tasks [18,20]. So we must account for it or waiting processes will never wake. See the comment on countSortedBoundaryValues [68] for more details.

[54] Decrement the latch by the count previously calculated, according to whether we sorted or partitioned.

Armed with our definition, here’s an example usage in a sentence:

int[] a = someLargeRandomArray();
ExecutorService pool =
        Executors.newFixedThreadPool(Runtime.getRuntime().availableProcessors());
LatchQuicksortTask sortingTask = new LatchQuicksortTask(a, pool);

pool.execute(sortingTask);
sortingTask.waitUntilSorted();

On a quad-core machine, this algorithm nets 300% speedups over serial quicksort! So it works, what’s not to like?

  1. The execution model still has a serial bottleneck. Our thread pool executor service uses a single, shared blocking queue to schedule tasks. The algorithm will not scale linearly as we add cores and threads.
  2. Complexity. This approach goes a long way to obscure C.A.R. Hoare’s classic recursive algorithm. The count-down latch derivative scarcely resembles its ancestor.

We can do better.

fork/join (approach 2)

As of JDK 7, Java has a first-class Fork/Join framework, compliments of the immortal Doug Lea. A detailed explanation of the Fork/Join concepts and Java implementation is well deserving of its own post.

In brief, the shared-queue model is replaced with one deque (double-ended queue) per thread. Each thread pushes and pulls work at one end of its own deque. When a thread runs out of work, it steals work from the opposite end of another thread’s deque.
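The two-ended access pattern can be illustrated with a plain java.util.ArrayDeque. This is only a single-threaded picture of which end each party uses; it is not thread-safe and is not how ForkJoinPool is implemented. The class name, method name, and task labels are made up.

```java
import java.util.ArrayDeque;
import java.util.Deque;

class DequeEndsDemo {
    // Returns { task the owner runs next, task a thief would steal }.
    static String[] ownerAndThief() {
        Deque<String> deque = new ArrayDeque<String>();

        // The owner pushes tasks at the tail as it splits work recursively,
        // so the oldest entries tend to be the largest pieces of work.
        deque.addLast("sort [0, 8000)");
        deque.addLast("sort [0, 4000)");
        deque.addLast("sort [0, 2000)");

        String ownerTakes = deque.pollLast();  // owner works at the tail
        String thiefTakes = deque.pollFirst(); // thief takes from the head
        return new String[] { ownerTakes, thiefTakes };
    }
}
```

Here the owner picks up the most recently pushed task while the thief walks away with the oldest one, so the two rarely touch the same end.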

If that sounds complicated, it isn’t, at least conceptually. One paragraph just doesn’t do it justice, so I point you to the canonical paper by Doug Lea, A Java Fork/Join Framework. If you learn better visually, David Liebke presented some stunning visuals for Fork/Join at Clojure Conj 2010 (slide deck). But, if you really want a trip, download JDK 7 and check out the Fork/Join source. In implementation, there are a lot of really cool moving parts to blow your mind.

For our purposes, Fork/Join addresses the two shortcomings of LatchQuicksortTask.

  1. As long as each thread in a Fork/Join pool is pushing/pulling work from one end of its own deque, no synchronization is needed! Shared mutable state (and synchronization) only comes into play when threads steal work from each other.
  2. The Fork/Join framework automatically takes care of signaling threads that are waiting on tasks to finish, so we can write our algorithm in a style more akin to serial quicksort.

In JDK 7, Runnable is to ExecutorService as ForkJoinTask is to ForkJoinPool. ForkJoinTask is more general than we need, since each of our sorting tasks has no need to return a result. So we use a convenience class, RecursiveAction, that inherits from ForkJoinTask<Void>.
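For contrast, here is a sketch of a task that does return a result, using the sibling convenience class RecursiveTask<V>. The SumTask class, its sum helper, and the threshold value are invented for illustration and have nothing to do with the sorting code.

```java
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.RecursiveTask;

class SumTask extends RecursiveTask<Long> {
    private static final int SERIAL_THRESHOLD = 0x1000;

    private final int[] a;
    private final int from; // inclusive
    private final int to;   // exclusive

    SumTask(int[] a, int from, int to) {
        this.a = a;
        this.from = from;
        this.to = to;
    }

    @Override
    protected Long compute() {
        if (to - from <= SERIAL_THRESHOLD) {
            long sum = 0;
            for (int i = from; i < to; i++) {
                sum += a[i];
            }
            return sum;
        }
        int mid = from + (to - from) / 2;
        SumTask leftHalf = new SumTask(a, from, mid);
        leftHalf.fork();                                   // queue the left half
        long rightSum = new SumTask(a, mid, to).compute(); // do the right half now
        return leftHalf.join() + rightSum;                 // join() hands back the result
    }

    // Convenience entry point: sums the whole array on a fresh pool.
    static long sum(int[] a) {
        ForkJoinPool pool = new ForkJoinPool();
        try {
            return pool.invoke(new SumTask(a, 0, a.length));
        } finally {
            pool.shutdown();
        }
    }
}
```

A sorting task has nothing to hand back from join(), which is why RecursiveAction is the better fit for Quicksort.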

Here’s ForkJoinQuicksortTask (full source here: ForkJoinQuicksort.java).

 1public class ForkJoinQuicksortTask extends RecursiveAction {
 2    private static final int SERIAL_THRESHOLD = 0x1000;
 3
 4    private final int[] a;
 5    private final int left;
 6    private final int right;
 7
 8    public ForkJoinQuicksortTask(int[] a) { this(a, 0, a.length - 1); }
 9
10    private ForkJoinQuicksortTask(int[] a, int left, int right) {
11        this.a = a;
12        this.left = left;
13        this.right = right;
14    }
15
16    @Override
17    protected void compute() {
18        if (serialThresholdMet()) {
19            Arrays.sort(a, left, right + 1);
20        } else {
21            int pivotIndex = partition(a, left, right);
22            ForkJoinTask<Void> t1 = null;
23
24            if (left < pivotIndex)
25                t1 = new ForkJoinQuicksortTask(a, left, pivotIndex).fork();
26            if (pivotIndex + 1 < right)
27                new ForkJoinQuicksortTask(a, pivotIndex + 1, right).invoke();
28
29            if (t1 != null)
30                t1.join();
31        }
32    }
33
34    private boolean serialThresholdMet() { return right - left < SERIAL_THRESHOLD; }
35}

[1] Pulls in all that Fork/Join goodness. Yum yum.

[25] Pushes a new task onto the current thread’s deque for later execution. The current thread might execute it or another thread may steal it.

[27] Executes this task right away. That may (recursively) result in pushing other tasks onto the deque for later execution or stealing.

[30] If we spawned a fork [25], continue working on tasks until that fork completes.

And here’s how to use it:

int[] a = someLargeRandomArray();
ForkJoinPool pool = new ForkJoinPool(Runtime.getRuntime().availableProcessors());

pool.invoke(new ForkJoinQuicksortTask(a)); // blocks until a is sorted

Fork/Join makes this quite elegant. Fully parallel, minimal shared mutable state, and it even bears a striking resemblance to serial quicksort!

If you want to try this out on your own machine:

  1. Download JDK 7
  2. Use git to clone my project: https://github.com/pmbauer/parallel/
  3. From the project root: mvn test


email comments to paul@bauer.codes
