Java multi-threaded unit testing

With Java 5 and JUnit 4, writing multi-threaded unit tests has never been easier.

Let’s take a brain-dead ID generator as our example domain object:

    /**
     * Generates sequential unique IDs starting with 1, 2, 3, and so on.
     * <p>
     * This class is NOT thread-safe.
     * </p>
     */
    static class BrokenUniqueIdGenerator {
        private long counter = 0;

        public long nextId() {
            return ++counter;
        }
    }
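Why isn’t it thread-safe? ++counter on a plain long field is three steps: a read, an add, and a write. Another thread can run between those steps. The minimal sketch below replays one such interleaving by hand, in a single thread, so the lost update happens every time. LostUpdateDemo is a made-up name for illustration and is not part of the test.

```java
class LostUpdateDemo {
    /** Shared state, playing the role of BrokenUniqueIdGenerator.counter. */
    private long counter = 0;

    /**
     * Replays by hand one bad interleaving of two threads running ++counter.
     * Returns {ID seen by thread A, ID seen by thread B, final counter value}.
     */
    static long[] replayLostUpdate() {
        LostUpdateDemo shared = new LostUpdateDemo();
        long readByA = shared.counter; // A: reads 0
        long readByB = shared.counter; // B: also reads 0, before A has written back
        long idA = readByA + 1;        // A: adds one
        shared.counter = idA;          // A: writes 1
        long idB = readByB + 1;        // B: adds one to its stale copy
        shared.counter = idB;          // B: writes 1, overwriting A's update
        return new long[] { idA, idB, shared.counter };
    }

    public static void main(String[] args) {
        long[] r = replayLostUpdate();
        // Both "threads" got ID 1, and the counter ends at 1 instead of 2.
        System.out.println("A=" + r[0] + " B=" + r[1] + " counter=" + r[2]); // A=1 B=1 counter=1
    }
}
```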

This is how you test this class under loads of 1, 2, 4, 8, 16, and 32 threads. We use the java.util.concurrent API to call our domain object concurrently from multiple threads, and JUnit @Test methods to run the test and verify the results.

    @Test
    public void test01() throws InterruptedException, ExecutionException {
        test(1);
    }

    @Test
    public void test02() throws InterruptedException, ExecutionException {
        test(2);
    }

    @Test
    public void test04() throws InterruptedException, ExecutionException {
        test(4);
    }

    @Test
    public void test08() throws InterruptedException, ExecutionException {
        test(8);
    }

    @Test
    public void test16() throws InterruptedException, ExecutionException {
        test(16);
    }

    @Test
    public void test32() throws InterruptedException, ExecutionException {
        test(32);
    }

    private void test(final int threadCount) throws InterruptedException, ExecutionException {
        final BrokenUniqueIdGenerator domainObject = new BrokenUniqueIdGenerator();
        Callable<Long> task = new Callable<Long>() {
            @Override
            public Long call() {
                return domainObject.nextId();
            }
        };
        List<Callable<Long>> tasks = Collections.nCopies(threadCount, task);
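        ExecutorService executorService = Executors.newFixedThreadPool(threadCount);
        List<Future<Long>> futures;
        try {
            // Blocks until every task has completed (normally or with an exception).
            futures = executorService.invokeAll(tasks);
        } finally {
            // Release the pool's threads; otherwise every test run leaks threadCount threads.
            executorService.shutdown();
        }
        List<Long> resultList = new ArrayList<Long>(futures.size());
        // Check for exceptions
        for (Future<Long> future : futures) {
            // Throws an ExecutionException if the task threw an exception.
            resultList.add(future.get());
        }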
        // Validate the IDs
        Assert.assertEquals(threadCount, futures.size());
        List<Long> expectedList = new ArrayList<Long>(threadCount);
        for (long i = 1; i <= threadCount; i++) {
            expectedList.add(i);
        }
        Collections.sort(resultList);
        Assert.assertEquals(expectedList, resultList);
    }
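The “Check for exceptions” loop leans on how Future.get() behaves: if call() threw, get() throws an ExecutionException whose getCause() is the original exception. Here is a small stand-alone sketch of that, using a plain main method instead of JUnit and a hypothetical failing task (ExceptionPropagationDemo is an illustrative name). In test() above, the ExecutionException simply propagates out of the @Test method and JUnit fails the test with it.

```java
import java.util.Collections;
import java.util.List;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;

class ExceptionPropagationDemo {

    /** Runs a failing task through invokeAll and returns the cause that future.get() reports. */
    static Throwable causeSeenByTest() throws InterruptedException {
        // Hypothetical task standing in for a nextId() call that blows up.
        Callable<Long> failingTask = new Callable<Long>() {
            @Override
            public Long call() {
                throw new IllegalStateException("simulated failure in nextId()");
            }
        };
        ExecutorService executorService = Executors.newFixedThreadPool(2);
        try {
            List<Future<Long>> futures = executorService.invokeAll(Collections.nCopies(2, failingTask));
            for (Future<Long> future : futures) {
                try {
                    future.get();
                } catch (ExecutionException e) {
                    // The task's own exception is wrapped; getCause() unwraps it.
                    return e.getCause();
                }
            }
            return null; // not reached here: every task fails
        } finally {
            executorService.shutdown();
        }
    }

    public static void main(String[] args) throws InterruptedException {
        // prints java.lang.IllegalStateException: simulated failure in nextId()
        System.out.println(causeSeenByTest());
    }
}
```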

Let’s walk through the test(int threadCount) method. We start by creating our domain object:

        final BrokenUniqueIdGenerator domainObject = new BrokenUniqueIdGenerator();

This class has one method, nextId(), which we wrap in a task, an instance of Callable<Long>:

        Callable<Long> task = new Callable<Long>() {
            @Override
            public Long call() {
                return domainObject.nextId();
            }
        };

This is just a generic way to fit our API call into the Java concurrency API.
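If you can use Java 8 or later (the code in this post targets Java 5/6), a method reference does the same adapting with less ceremony, autoboxing nextId()’s long into the Long that call() returns. A sketch, assuming Java 8+, with a stand-in copy of UniqueIdGenerator so it compiles on its own (MethodReferenceTaskDemo and taskFor are illustrative names):

```java
import java.util.concurrent.Callable;
import java.util.concurrent.atomic.AtomicLong;

class MethodReferenceTaskDemo {
    // Stand-in with the same nextId() shape as the generators in this post.
    static class UniqueIdGenerator {
        private final AtomicLong counter = new AtomicLong();

        public long nextId() {
            return counter.incrementAndGet();
        }
    }

    static Callable<Long> taskFor(UniqueIdGenerator domainObject) {
        // Java 8+: adapts long nextId() to Long call(), boxing the result.
        return domainObject::nextId;
    }

    public static void main(String[] args) throws Exception {
        Callable<Long> task = taskFor(new UniqueIdGenerator());
        System.out.println(task.call()); // prints 1
        System.out.println(task.call()); // prints 2
    }
}
```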

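We then build a list that contains this task once for each thread. Collections.nCopies does not clone the task; it repeats the same reference, so every thread calls the same domain object: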

        List<Callable<Long>> tasks = Collections.nCopies(threadCount, task);
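A quick stand-alone check of what nCopies gives you: a list of the requested size in which every element is the same Callable instance, and which cannot be modified (fine here, since invokeAll only reads it). NCopiesDemo and the dummy task returning 42 are illustrative only.

```java
import java.util.Collections;
import java.util.List;
import java.util.concurrent.Callable;

class NCopiesDemo {
    static List<Callable<Long>> makeTasks(int threadCount) {
        // Dummy task; in the test this would call domainObject.nextId().
        Callable<Long> task = new Callable<Long>() {
            @Override
            public Long call() {
                return 42L;
            }
        };
        return Collections.nCopies(threadCount, task);
    }

    public static void main(String[] args) {
        List<Callable<Long>> tasks = makeTasks(4);
        System.out.println(tasks.size()); // 4
        // Every element is the very same Callable object.
        System.out.println(tasks.get(0) == tasks.get(3)); // true
        try {
            tasks.add(tasks.get(0));
        } catch (UnsupportedOperationException e) {
            System.out.println("immutable"); // nCopies returns an immutable list
        }
    }
}
```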

Next, we create a thread pool sized at least as large as the number of threads we want to test; here we use exactly threadCount. A smaller pool would queue some of the tasks and run fewer of them at the same time.

        ExecutorService executorService = Executors.newFixedThreadPool(threadCount);
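To see why the pool size matters, this sketch (hypothetical PoolSizeDemo) records which pool thread ran each task. With one pool thread, all eight tasks run one after another on that thread, so they never execute concurrently and a race in nextId() could never show up.

```java
import java.util.Collections;
import java.util.HashSet;
import java.util.Set;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;

class PoolSizeDemo {
    /** Runs taskCount tasks on a pool of poolSize threads; returns how many distinct threads ran them. */
    static int distinctThreads(int poolSize, int taskCount) throws Exception {
        Callable<String> task = new Callable<String>() {
            @Override
            public String call() {
                return Thread.currentThread().getName();
            }
        };
        ExecutorService executorService = Executors.newFixedThreadPool(poolSize);
        try {
            Set<String> names = new HashSet<String>();
            for (Future<String> future : executorService.invokeAll(Collections.nCopies(taskCount, task))) {
                names.add(future.get());
            }
            return names.size();
        } finally {
            executorService.shutdown();
        }
    }

    public static void main(String[] args) throws Exception {
        // One pool thread: the eight tasks queue up and run sequentially.
        System.out.println(distinctThreads(1, 8)); // 1
        // Two pool threads: at most two tasks run at any moment.
        System.out.println(distinctThreads(2, 8));
    }
}
```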

Then we ask the executor to run all the tasks concurrently on threads from the pool:

        List<Future<Long>> futures = executorService.invokeAll(tasks);

The call to invokeAll blocks until every task has completed. Each task runs on a thread from the pool, which invokes the task’s call() method, which in turn calls our domain object’s API, nextId().
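The ExecutorService.invokeAll Javadoc promises that Future.isDone() is true for every element of the returned list, which is why the later future.get() calls never wait. A minimal sketch that checks this with deliberately slow tasks (the 50 ms sleep and the InvokeAllDemo name are arbitrary choices for illustration):

```java
import java.util.Collections;
import java.util.List;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;

class InvokeAllDemo {
    /** Returns true if every future is already done when invokeAll returns. */
    static boolean allDoneAfterInvokeAll(int threadCount) throws InterruptedException {
        Callable<Long> slowTask = new Callable<Long>() {
            @Override
            public Long call() throws InterruptedException {
                Thread.sleep(50); // simulate work so the tasks take a noticeable time
                return 1L;
            }
        };
        ExecutorService executorService = Executors.newFixedThreadPool(threadCount);
        try {
            List<Future<Long>> futures = executorService.invokeAll(Collections.nCopies(threadCount, slowTask));
            // invokeAll has already waited for every task to finish.
            for (Future<Long> future : futures) {
                if (!future.isDone()) {
                    return false;
                }
            }
            return true;
        } finally {
            executorService.shutdown();
        }
    }

    public static void main(String[] args) throws InterruptedException {
        System.out.println(allDoneAfterInvokeAll(4)); // true
    }
}
```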

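When you run this test case against BrokenUniqueIdGenerator, it may pass on some runs and fail on others: ++counter is a read-modify-write, so the test only fails when two threads happen to interleave inside nextId(). With a single call per thread, that overlap can be rare.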
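As the comments below point out, with one nextId() call per thread the threads rarely overlap, so the broken generator often passes. One common way to improve the odds, sketched here as an alternative and not the approach used above, is to (1) line the threads up behind a CountDownLatch so they all start at once, (2) have each thread call nextId() many times, and (3) check for duplicate IDs with a Set. The IdGenerator interface, the latch names, and the 8 × 10000 load are assumptions for illustration. This only makes a failure more likely: a run that reports no duplicates does not prove a class thread-safe.

```java
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import java.util.concurrent.Callable;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.concurrent.atomic.AtomicLong;

class StressTestSketch {
    /** Hypothetical interface so one harness can drive both generators from this post. */
    interface IdGenerator {
        long nextId();
    }

    static class BrokenUniqueIdGenerator implements IdGenerator {
        private long counter = 0;

        public long nextId() {
            return ++counter;
        }
    }

    static class UniqueIdGenerator implements IdGenerator {
        private final AtomicLong counter = new AtomicLong();

        public long nextId() {
            return counter.incrementAndGet();
        }
    }

    /** Returns how many duplicate IDs were handed out; 0 means no race was observed on this run. */
    static int countDuplicates(final IdGenerator generator, int threadCount, final int idsPerThread)
            throws Exception {
        final CountDownLatch ready = new CountDownLatch(threadCount);
        final CountDownLatch start = new CountDownLatch(1);
        Callable<List<Long>> task = new Callable<List<Long>>() {
            @Override
            public List<Long> call() throws InterruptedException {
                ready.countDown(); // this thread is up and waiting
                start.await(); // block until the main thread opens the gate
                List<Long> ids = new ArrayList<Long>(idsPerThread);
                for (int i = 0; i < idsPerThread; i++) {
                    ids.add(generator.nextId());
                }
                return ids;
            }
        };
        // One thread per task, or ready.await() below would never return.
        ExecutorService executorService = Executors.newFixedThreadPool(threadCount);
        try {
            List<Future<List<Long>>> futures = new ArrayList<Future<List<Long>>>(threadCount);
            for (int i = 0; i < threadCount; i++) {
                futures.add(executorService.submit(task));
            }
            ready.await(); // every worker is parked on the start gate
            start.countDown(); // release them all at once to maximize overlap
            Set<Long> seen = new HashSet<Long>();
            int duplicates = 0;
            for (Future<List<Long>> future : futures) {
                for (Long id : future.get()) {
                    if (!seen.add(id)) {
                        duplicates++;
                    }
                }
            }
            return duplicates;
        } finally {
            executorService.shutdown();
        }
    }

    public static void main(String[] args) throws Exception {
        // Expected to print 0 on every run.
        System.out.println("UniqueIdGenerator duplicates: "
                + countDuplicates(new UniqueIdGenerator(), 8, 10000));
        // Often non-zero on a multi-core machine, but 0 does not prove thread safety.
        System.out.println("BrokenUniqueIdGenerator duplicates: "
                + countDuplicates(new BrokenUniqueIdGenerator(), 8, 10000));
    }
}
```

To fold this into the JUnit test, you would assert that countDuplicates returns 0 from each @Test method.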

That’s multithreaded testing with Java 5 and JUnit 4. Voila!

BTW, a correct, thread-safe implementation is:

    /**
     * Generates sequential unique IDs starting with 1, 2, 3, and so on.
     * <p>
     * This class is thread-safe.
     * </p>
     */
    static class UniqueIdGenerator {
        private final AtomicLong counter = new AtomicLong();

        public long nextId() {
            return counter.incrementAndGet();
        }
    }
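AtomicLong performs the increment atomically without locking. If you prefer intrinsic locking, marking the original method synchronized also works. Here is a sketch with a hypothetical driver method that calls it from plain threads and confirms that no ID is lost:

```java
class SynchronizedUniqueIdGenerator {
    private long counter = 0;

    /** synchronized makes the read-increment-write atomic and its result visible to other threads. */
    public synchronized long nextId() {
        return ++counter;
    }

    /** Hypothetical driver: calls nextId() from several threads, then returns the next ID. */
    static long nextIdAfterConcurrentCalls(int threadCount, final int callsPerThread)
            throws InterruptedException {
        final SynchronizedUniqueIdGenerator generator = new SynchronizedUniqueIdGenerator();
        Thread[] threads = new Thread[threadCount];
        for (int t = 0; t < threadCount; t++) {
            threads[t] = new Thread(new Runnable() {
                @Override
                public void run() {
                    for (int i = 0; i < callsPerThread; i++) {
                        generator.nextId();
                    }
                }
            });
            threads[t].start();
        }
        for (Thread thread : threads) {
            thread.join();
        }
        // No update was lost, so exactly threadCount * callsPerThread IDs were handed out.
        return generator.nextId();
    }

    public static void main(String[] args) throws InterruptedException {
        System.out.println(nextIdAfterConcurrentCalls(4, 10000)); // 40001
    }
}
```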

The full listing, which tests the thread-safe UniqueIdGenerator, is:

    import java.util.ArrayList;
    import java.util.Collections;
    import java.util.List;
    import java.util.concurrent.Callable;
    import java.util.concurrent.ExecutionException;
    import java.util.concurrent.ExecutorService;
    import java.util.concurrent.Executors;
    import java.util.concurrent.Future;
    import java.util.concurrent.atomic.AtomicLong;

    import org.junit.Assert;
    import org.junit.Test;

    public class MultiThreadedTestCase {

        /**
         * Generates sequential unique IDs starting with 1, 2, 3, and so on.
         * <p>
         * This class is NOT thread-safe.
         * </p>
         */
        static class BrokenUniqueIdGenerator {
            private long counter = 0;

            public long nextId() {
                return ++counter;
            }
        }

        /**
         * Generates sequential unique IDs starting with 1, 2, 3, and so on.
         * <p>
         * This class is thread-safe.
         * </p>
         */
        static class UniqueIdGenerator {
            private final AtomicLong counter = new AtomicLong();

            public long nextId() {
                return counter.incrementAndGet();
            }
        }

        private void test(final int threadCount) throws InterruptedException, ExecutionException {
            final UniqueIdGenerator domainObject = new UniqueIdGenerator();
            Callable<Long> task = new Callable<Long>() {
                @Override
                public Long call() {
                    return domainObject.nextId();
                }
            };
            List<Callable<Long>> tasks = Collections.nCopies(threadCount, task);
            ExecutorService executorService = Executors.newFixedThreadPool(threadCount);
            List<Future<Long>> futures;
            try {
                // Blocks until every task has completed (normally or with an exception).
                futures = executorService.invokeAll(tasks);
            } finally {
                // Release the pool's threads; otherwise every test run leaks threadCount threads.
                executorService.shutdown();
            }
            List<Long> resultList = new ArrayList<Long>(futures.size());
            // Check for exceptions
            for (Future<Long> future : futures) {
                // Throws an ExecutionException if the task threw an exception.
                resultList.add(future.get());
            }
            // Validate the IDs
            Assert.assertEquals(threadCount, futures.size());
            List<Long> expectedList = new ArrayList<Long>(threadCount);
            for (long i = 1; i <= threadCount; i++) {
                expectedList.add(i);
            }
            Collections.sort(resultList);
            Assert.assertEquals(expectedList, resultList);
        }

        @Test
        public void test01() throws InterruptedException, ExecutionException {
            test(1);
        }

        @Test
        public void test02() throws InterruptedException, ExecutionException {
            test(2);
        }

        @Test
        public void test04() throws InterruptedException, ExecutionException {
            test(4);
        }

        @Test
        public void test08() throws InterruptedException, ExecutionException {
            test(8);
        }

        @Test
        public void test16() throws InterruptedException, ExecutionException {
            test(16);
        }

        @Test
        public void test32() throws InterruptedException, ExecutionException {
            test(32);
        }
    }

Note: I used Oracle Java 1.6.0_24 (64-bit) on Windows 7 (64-bit).

13 thoughts on “Java multi-threaded unit testing”

  1. edgrip

    Nice trick about Callable usage, but this test suite is not reliable since there is a chance for it to pass with the BrokenUniqueIdGenerator…

  2. Sander Verhagen

    This was of great help. Just a small detail:

    This should be the other way around:
    Assert.assertEquals(futures.size(), threadCount);

    Thus as follows:
    Assert.assertEquals(threadCount, futures.size());

  3. Norbert

    It’s already a bit old, but just a short note. I have to add a small sleep in the nextId method (e.g. 10ms), otherwise both generators are okay. Nice work, btw!
