Monday, January 22, 2018

Writing a Scheduler for Tasks with DAG Dependencies using CompletableFuture and Async I/O

I had to write a task scheduler in Java recently. The tasks had dependencies on one another, and the dependencies formed a Directed Acyclic Graph (DAG) when viewed as a graph. Each task must forward its results to the nodes that depend on it; otherwise those nodes can't start executing. For example, suppose we have a bunch of tasks with dependencies like so:

B --> A
C --> B
D --> B, C
E --> A, C, D
F --> A, E

Here, B needs A's results to start, C needs B's results to start, D needs both B's and C's results to start, E needs A's, C's, and D's results to start, and finally F needs A's and E's results to start, after which the whole graph completes execution when F finishes. The final result of the graph is whatever F produces.

So, how do we design a scheduler (and executor) for such a scenario?

Well, here are some of the things I did/used to write the scheduler:

1. Create a graph of nodes where each node represents a task.

2. Make sure the graph is a DAG (Directed Acyclic Graph): i) there are no cycles in the graph, ii) there is at least one node that does not depend on any other node. We call all such nodes the root node(s); execution starts with the root nodes. Save the indegrees of all nodes: the root nodes are the nodes with an indegree of 0. We also save all the nodes whose outdegree is 0. We call these nodes leaf nodes. Our task scheduler will stop when all leaf nodes have been scheduled.
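As a rough sketch of steps 1 and 2 (the names and types here are hypothetical stand-ins, not the scheduler's real ones), the example graph from above can be checked for its roots and leaves like this:

```java
import java.util.*;

public class DagCheck {
    public static void main(String[] args) {
        // task -> list of tasks it depends on (the example DAG from the post)
        Map<String, List<String>> deps = new HashMap<>();
        deps.put("A", Collections.emptyList());
        deps.put("B", Arrays.asList("A"));
        deps.put("C", Arrays.asList("B"));
        deps.put("D", Arrays.asList("B", "C"));
        deps.put("E", Arrays.asList("A", "C", "D"));
        deps.put("F", Arrays.asList("A", "E"));

        // indegree = number of dependencies; outdegree = number of dependents
        Map<String, Integer> indegree = new HashMap<>();
        Map<String, Integer> outdegree = new HashMap<>();
        for (String task : deps.keySet()) {
            indegree.put(task, deps.get(task).size());
            outdegree.putIfAbsent(task, 0);
            for (String d : deps.get(task)) {
                outdegree.merge(d, 1, Integer::sum);
            }
        }

        Set<String> roots = new TreeSet<>();   // indegree 0: start here
        Set<String> leaves = new TreeSet<>();  // outdegree 0: stop when scheduled
        for (String task : deps.keySet()) {
            if (indegree.get(task) == 0) roots.add(task);
            if (outdegree.get(task) == 0) leaves.add(task);
        }
        System.out.println("roots=" + roots + " leaves=" + leaves);
        // → roots=[A] leaves=[F]
    }
}
```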

3. I used Java 8's CompletableFuture async APIs to maximize parallelism and ease of development. If you haven't used them already, Java 8's CompletableFuture APIs are full of awesomeness! If you have many tasks to run in parallel, they come packed with a rich feature set that lets you do almost anything you need in a parallel/concurrent programming environment!

4. The key algorithm I used for maintaining the task dependencies is Topological Sorting. Topological sorting is a nifty algorithm that can order a set of nodes with dependencies among them, as long as the nodes do not form any cycle. By following a strategy similar to topological sorting, I was able to start executing tasks at the root node(s) and gradually advance the execution to the dependent nodes that were waiting for the results of the nodes they depended on. To traverse the DAG, I utilized callback functions in a clever way.

Let's look at the algorithm for executing tasks in a topologically sorted graph of tasks:

i) Create a blocking queue of tasks. At the start, add all tasks with an indegree of 0 to the queue. We will use this queue to add tasks once all their dependencies are met. Note that we are using a blocking queue so that the task scheduler will wait when there are no tasks available to execute. We will quit when all leaf nodes have been scheduled for execution.

ii) Enter an infinite loop for executing tasks. This loop looks somewhat like this:


while (true) {
    // Waiting for a task to be available.
    // Blocking call to the task queue. Tasks are added to the queue only when
    // they have an indegree of 0.
    TaskNode currentTask = taskQueue.take();

    CompletableFuture<T> taskFuture = scheduleTask(currentTask, taskQueue,
                                                  taskResults,
                                                  nodeIndegrees, functionProvider);
    futures.put(currentTask.getName(), taskFuture);

    if (this.leafNodes.contains(currentTask.getName())) {
        leafNodeScheduled++;
        if (leafNodeScheduled == this.leafNodes.size()) {
            // We've scheduled all leaf nodes in the task set. Done!
            break;
        }
    }
}

T is the type of the result that each task produces. The scheduleTask() method actually schedules the tasks asynchronously:

private CompletableFuture<T> scheduleTask(final TaskNode task,
                                          final BlockingQueue<TaskNode> taskQueue,
                                          final Map<String, T> taskResults,
                                          final Map<String, Integer> nodeIndegrees,
                                          final FunctionProvider<T> functionProvider) {
    String taskName = task.getName();

    // Create a dummy future to start the task with the desired inputs from all
    // dependencies (tasks that this task depends on)

    List<T> dependenciesRecords = new ArrayList<>();
    for (String dependencyName : task.getDependencies()) {
        T dependencyResult = taskResults.get(dependencyName);
        if (dependencyResult != null) {
            dependenciesRecords.add(dependencyResult);
        }
    }
    CompletableFuture<List<T>> startCf = CompletableFuture.supplyAsync(() -> dependenciesRecords);
    PostCompletionTask postCompletionTask = new PostCompletionTask(task, taskQueue, taskResults, nodeIndegrees);
    // The task's function consumes its dependencies' results and produces this task's result.
    Function<List<T>, T> function = functionProvider.newFunction(taskName);
    CompletableFuture<T> taskFuture = startCf
                                       .thenApplyAsync(function, this.dagExecutor)
                                       .thenApplyAsync(postCompletionTask::whenDone, this.dagExecutor);

    return taskFuture;
} 

We do a bunch of things to make sure the task starts with the results from all the tasks it depends on. Whenever a task completes, we save its results in a Map so that we can feed those results to the next task. The simplest async job using Java 8's CompletableFuture is:

CompletableFuture<String> future  = CompletableFuture.supplyAsync(() -> "Hello");

Here, we have a completable future that simply returns the string "Hello". Similarly, we create a simple async task that returns the results of all the previous jobs that the current job depends on:

CompletableFuture<List<T>> startCf = CompletableFuture.supplyAsync(() -> dependenciesRecords);

Then we start the current job, which is described by some Function. We chain the current job with a post-completion callback task so that we can decrement the indegrees of the dependent tasks and start each one when its indegree reaches 0.

PostCompletionTask postCompletionTask = new PostCompletionTask(task, taskQueue, taskResults, nodeIndegrees);

Here is what the PostCompletionTask looks like:

private class PostCompletionTask {
    private final TaskNode taskNode;
    private final BlockingQueue<TaskNode> taskQueue;
    private final Map<String, T> taskResults;
    private final Map<String, Integer> nodeIndegrees;

    public PostCompletionTask(final TaskNode taskNode,
                              final BlockingQueue<TaskNode> taskQueue,
                              final Map<String, T> taskResults,
                              final Map<String, Integer> nodeIndegrees) {
        this.taskNode = taskNode;
        this.taskQueue = taskQueue;
        this.taskResults = taskResults;
        this.nodeIndegrees = nodeIndegrees;
    }

    public T whenDone(T t) {
        Collection<TaskNode> outNodes = edges.get(this.taskNode.getName());

        for (TaskNode outNode : outNodes) {
            synchronized (this.nodeIndegrees) {
                // We need to synchronize on the indegree map as multiple threads
                // can execute the following block and may simultaneously see the
                // value going to 0!

                int currCount = this.nodeIndegrees.getOrDefault(outNode.getName(), -1);
                this.nodeIndegrees.put(outNode.getName(), currCount - 1);
                if ((currCount - 1) == 0) {
                    this.taskQueue.offer(outNode);
                }
            }
        }

        this.taskResults.put(this.taskNode.getName(), t);
        // Return the original results of the task whose completion we're handling.
        return t;
    }
}

The main job of the completion callback (whenDone()) is to decrement the indegree of all the nodes that depend on the current node (whose completion callback we are executing). If, in doing so, we find a node whose indegree became 0, we add that node to the blocking task queue to be picked up by our scheduler code for scheduling and execution. This is one of the core steps of topological sorting. When the indegree of a node becomes 0, that node is up next for action!

Note that we also save the task's results in a global map so that we can look them up when we schedule the tasks that depend on this task. We also return the results from the callback in case we want to chain this post-completion task to other async tasks.
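To tie it all together, once the loop has scheduled every leaf node the caller can block on the leaf futures for the graph's final results. A minimal sketch, assuming a futures map and leafNodes set like the ones in the scheduling loop above (the names and result type here are stand-ins, not the real scheduler's):

```java
import java.util.*;
import java.util.concurrent.*;

public class JoinLeaves {
    public static void main(String[] args) {
        // Hypothetical stand-ins for the scheduler's `futures` map and `leafNodes` set.
        Map<String, CompletableFuture<String>> futures = new HashMap<>();
        futures.put("F", CompletableFuture.supplyAsync(() -> "F-result"));
        Set<String> leafNodes = Collections.singleton("F");

        // Wait for every leaf future, then read off the graph's final results.
        CompletableFuture<?>[] leaves = leafNodes.stream()
                .map(futures::get)
                .toArray(CompletableFuture[]::new);
        CompletableFuture.allOf(leaves).join();

        for (String leaf : leafNodes) {
            System.out.println(leaf + " -> " + futures.get(leaf).join());
        }
        // → F -> F-result
    }
}
```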

Overall, it was a great way for me to learn about the CompletableFuture async APIs in Java 8. Applying graph algorithms on top made it even more enjoyable.

Unfortunately, this is code I wrote for work, so I can't really write the entire code of the scheduler. My goal here is to highlight the algorithms and tools used in the process without revealing any proprietary code.

Hope I was able to illustrate how one can write a task scheduler using Java's CompletableFuture APIs. Please let me know if you have any questions.

Adios!


Tuesday, December 26, 2017

Consistent Hashing implementation using BST in Python

I was reading a bunch of papers and articles on Distributed Systems for my work recently when I stumbled into the concept of consistent hashing. I had some ideas about distributed hash tables (DHT) before, but I had never delved deeper into the subject. Out of curiosity I studied up some more on consistent hashing and was amazed to find such a simple algorithm with such extraordinary utility in distributed systems. The implementation is also pretty straightforward when you use an efficient container like the binary search tree (BST).

The idea

Consistent hashing is a hashing technique that maps keys to nodes under the assumption that nodes may join and leave the system at random. The defining feature is that when a new node joins or an existing node leaves the system, only a small set of key-to-node assignments needs to change. If we have M keys and N nodes in our system, the expected number of keys re-mapped when a node is added or removed is M/N. This is a significant improvement over regular hashing, which would typically re-map almost all the keys when a node leaves or joins.

In traditional hashing we have a fixed number of bins or buckets to place our keys in. Consistent hashing removes this limitation of a fixed number of bins. Instead, we hash the keys into a virtually unlimited integer space and place our bins randomly throughout the same space. The bin that is closest to a hashed key in the clockwise direction is the target bin for that key!

Algorithm summary

Brief description of how consistent hashing works:
  1. All keys and nodes are mapped to the same integer space (typically something like 0 to 2^32 - 1 or 0 to 2^64 - 1).
  2. If we have N nodes, they are assigned IDs which are essentially the hash values of their names.
  3. If we have M keys, they are hashed to the same integer space where the nodes are already mapped.
  4. If a key's hash matches a node's ID, then we trivially return that node. Otherwise, we find the next node ID greater than the key's hash value. If no such node ID is found before the positive end of the range, we wrap around and return the node ID with the smallest value. Thus the nodes basically form a ring.
  5. When a new node is added to the system it is placed in the hash ring according to its ID (which is the hash value of its name). All keys between its ID and its predecessor's ID are then re-mapped to this node. So, the only node affected in the process is the node immediately following the new node in the ring: some of the keys that used to map to that node now map to the newly added node. The expected number of keys moved around is M/N.
  6. Similarly, when a node leaves the system, all the keys between the leaving node and the node immediately preceding it are re-mapped to the node immediately following the leaving node in the ring. Again, the expected number of keys moved around is M/N.

                                       Image source: https://www.toptal.com/big-data/consistent-hashing

Implementation

I am also currently working with Python for some of our deep learning projects at work (I usually code in Java at work), so it made sense to choose Python as the language of choice for writing up on consistent hashing!

The key ingredient in implementing the consistent hashing algorithm is using an efficient data structure to quickly look up the number that is equal to or greater than the key's hash value. One such data structure is the binary search tree (BST). A BST can store all the node IDs of existing nodes. When a key needs to be mapped to a node, we simply hash the key and look up the node ID nearest to the hash value of the key.

The binary search tree we use needed a little modification from the standard implementation to make it act like a sorted ring of values: when a lookup reaches the end of the BST without finding a target node, we wrap around and return the first (smallest) node.

The Python code implementing consistent hashing with the binary search tree (BST) from scratch can be found on my GitHub account here. The GitHub repo also contains test code to test the implementation.

Improved load balancing

The basic implementation can be improved further to make the key-to-node mapping more balanced by creating virtual nodes for each node and placing them randomly throughout the ring. In this setup, each node has K replicas placed around the ring. With fewer nodes the "gaps" between the nodes are wider, which may lead to some nodes receiving more keys than others. With replicas of nodes we effectively reduce the gap sizes and make the probability of hitting each node more uniform.

Tuesday, May 21, 2013

A C++ Thread Pool Implementation Using POSIX Threads

Threads are very useful but potentially tricky constructs in computer programming, and they are generally hard to get right. Building deadlock-free and correct multi-threaded software requires great care.

Threads are also expensive. The processor needs to do a context switch to jump between threads. Each context switch means saving the state of the currently executing thread and then loading the thread selected for execution. Creating a thread for every I/O operation or lengthy computation can grind the machine to a halt if the number of requests for I/O and/or computation is too high.

The middle ground between creating a thread for every request and doing everything in one thread is to create a pool of threads and reuse a thread as soon as it is done servicing a request. C++ does not have a built-in thread pool library (it didn't even have threading support prior to the C++11 standard that came out last year). Java has the Executor interface for this. With C++, some people use the Boost threading library or the Boost Asio library to achieve performance gains in their applications.

In this article we will design a very simple thread pool library using C++ and POSIX threads (also known as pthreads). We will not use any other external libraries like Boost.

The core idea of a thread pool is to have a number of worker threads always ready to accept work from the main process. The process receives requests for work and schedules those requests as tasks for the threads. This pattern resembles the well-known producer-consumer problem in Computer Science. There is a queue that is populated with tasks as they arrive in the process. The request-processing part is the producer and the threads are the consumers: the request processor inserts work items as they arrive, and the threads pick up one item at a time from the queue in a First-In-First-Out fashion.

So a thread pool basically consists of three primary actors:

1. A number of threads either waiting for or executing tasks.
2. A number of service requests to the server from clients. Each request is considered a task to the threads.
3. A queue holding the incoming service requests.




*image source: http://www.javamex.com/tutorials/threads/thread_pools.shtml

My C++ implementation of a thread pool:

In an attempt to dig deep into the multi-threading and asynchronous world of computation I developed a very simple thread pool using C++ and pthreads.

The source code for the thread pool with example usage can be found on my github page: https://github.com/bilash/threadpool

Each request is represented by an object of the class Task. Class Task essentially wraps a function pointer that points to the actual function constituting the task. Since it's a class, we also provide a functor-based interface to invoke the function pointer stored in the Task class.

The actual thread pool is managed by the ThreadPool class. It stores a vector of threads, a queue of tasks (instances of the class Task), a method to enqueue incoming tasks, and a method to execute the tasks in a thread. There are also helper methods to initialize and destroy (and clean up) the pool.

Thread synchronization:

When developing a multi-threaded service or application we almost always need to use locks to prevent data corruption and data races in our program. This essentially means data that will be accessed by more than one thread - also known as a critical section - needs to be protected by some kind of locking mechanism. The most popular of these locking mechanisms is the mutex (for mutual exclusion). A mutex achieves what its name suggests - it allows execution of the code in a mutually exclusive way!

In our thread pool program we use a mutex to protect our shared resource: the queue holding the tasks. The task queue is populated by the request-processing part of the program and is read by the threads waiting to pick up tasks. The threads also remove the tasks they pick up from the queue, since a task is no longer needed after it is executed. Since the queue is accessed (and modified) by multiple actors in the program, it is protected by a mutex lock.

The other synchronization mechanism used in the program is called Condition Variables. Condition variables are used to signal threads waiting on a condition to be true. In our program we use it to signal the threads that the queue has been populated with new tasks. The threads wait (put to sleep by the OS) while the task queue is empty. We wake up the waiting threads by using a condition variable.

Feel free to browse through the code and let me know if you find a bug or some issues with it. Again, it's a very minimalist code just to demonstrate the concept of thread pooling, so don't expect it to be very robust and flexible!

Thanks for reading!

Friday, April 5, 2013

Waiting for a child process with a timeout on Linux

Recently at work we were developing a backend server for a Web app. The server process creates a child process for each request that arrives, then waits for the child process to terminate. But since we couldn't wait indefinitely, we needed a wait() function with a timeout. Since Linux does not have such a function in its wait() family of system calls, we created a wrapper around the existing system call waitpid() that takes an additional boolean out-parameter, set to true or false depending on whether the wrapper function returned because of a timeout.

It looks something like this:

pid_t waitpid_with_timeout(pid_t pid, int *status, int options, int timeout_period, boolean* timed_out);

The body of the function essentially does this:

1. Set a signal handler for SIGALRM which doesn't do anything (we just need to know that the alarm went off) and mask all other signals.
2. Install the signal sigaction structure.
3. Set the alarm clock by calling the alarm() system call.
4. Call the Linux system call waitpid().
5. If waitpid() returned -1 and errno was set to EINTR, our alarm went off and we set timed_out to true. Otherwise, if waitpid() succeeded, we did not time out and the child process terminated before the timeout period specified in the parameter timeout_period.

After waitpid_with_timeout() returns, we check the timed_out parameter. If timed_out is set to true, we kill the child process explicitly:

kill(pid, SIGKILL);

Now, everything was all good and dandy with this implementation. Until during testing we found out that even though we called waitpid() in the function waitpid_with_timeout(), we did not collect the exit status of the child in the case of a timeout (when we explicitly killed the child with kill()). This was the backend of a Web application, so uncollected children were piling up with each request from the browser, and they were all becoming zombie processes!

We realized that the solution to this problem was simply another call to waitpid() after the child was explicitly killed with kill(). So when waitpid_with_timeout() returned with timed_out == true, we simply added another call to waitpid() after the call to kill():

kill(pid, SIGKILL);
waitpid(pid, &status, 0);

This solved our zombie process problem!

There is some interesting discussion of this topic on Stack Overflow if you are interested: http://stackoverflow.com/questions/282176/waitpid-equivalent-with-timeout

Sunday, September 9, 2012

Generating all permutations, combinations, and power set of a string (or set of numbers)

Combinatorics is a branch of mathematics that deals with counting discrete structures. Two concepts that often come up in the study of combinatorics are permutations and combinations of a set of discrete elements.

The number of possible permutations (and combinations) that can be generated from a set of discrete elements can be huge. We will leave the mathematical analysis here and instead focus on how to generate all permutations and combinations of a set of numbers/characters by using a computer program. We will be using the Java programming language to write the code for this.

A related concept is the Power Set. The power set of a set of discrete elements is the set of all subsets of the original set. Essentially, it is the set of combinations of all lengths (0 to the size of the set).

Generating all permutations of a string (of characters):

The key to understanding how we can generate all permutations of a given string is to imagine the string (which is essentially a set of characters) as a complete graph where the nodes are the characters of the string. This basically reduces the permutations generating problem into a graph traversal problem: given a complete graph, visit all nodes of the graph without visiting any node twice. How many different ways are there to traverse such a graph?

It turns out, each different way of traversing this graph is one permutation of the characters in the given string!

We can use Depth First Search (DFS) traversal technique to traverse this graph of characters. The important thing to keep in mind is that we must not visit a node twice in any "branch" of the depth-first tree that runs down from a node at the top of the tree to the leaf which denotes the last node in the current "branch".

              START
            /   |   \
           A    B    C
          / \  / \  / \
         B  C A  C A  B
         |  | |  | |  |
         C  B C  A B  A
In the above figure, a "branch" is the vertical line that connects all 3 characters. A "branch" is one permutation of the given string. In a recursive (DFS-based) solution the trick is to maintain an array that holds one such "branch" at any given time.

Here is the Java code:

void generatePermutations(char[] arr, int size, char[] branch, int level, boolean[] visited)
{
    if (level >= size-1)
    {
        System.out.println(branch);
        return;
    }
   
    for (int i = 0; i < size; i++)
    {
        if (!visited[i])
        {
            branch[++level] = arr[i];
            visited[i] = true;
            generatePermutations(arr, size, branch, level, visited);
            visited[i] = false;
            level--;
        }
    }
}

The above method can be called like this:

String str = "ABCD";
int n = str.length();
char[] arr = str.toCharArray();
boolean[] visited = new boolean[n];
for (int i = 0; i < n; i++)
    visited[i] = false;
char[] branch = new char[n];
generatePermutations(arr, n, branch, -1, visited);

The visited array keeps track of which nodes have been visited already.

Generating combinations of k elements:
Generating combinations of k elements from the given set follows an algorithm similar to the one used to generate all permutations. But since we don't want to repeat a set of characters even in a different order, we have to force the recursive calls not to follow branches that would repeat a set of characters.

If the given string is "ABC" and k = 2, our recursive tree will look like this:

           START
           /    \
          A      B
         / \      \
        B   C      C

Here we have to make sure that once we start a "branch" from a node (character), we never come back to that node (character) to start another "branch". So each new recursive call (to traverse a new "branch") must start from the following node (character)!

Here is the Java code for generating k combinations:

void combine(char[] arr, int k, int startId, char[] branch, int numElem)
{
    if (numElem == k)
    {
        System.out.println(Arrays.toString(branch));
        return;
    }
   
    for (int i = startId; i < arr.length; ++i)
    {
        branch[numElem++] = arr[i];
        combine(arr, k, i + 1, branch, numElem);
        --numElem;
    }
}

In the above code, the variable startId makes sure we never go back to an earlier character when starting a new "branch"; it is advanced for each new traversal.

To call the combine method above, do this:

int k = 2;
char[] input = "ABCD".toCharArray();
char[] branch = new char[k];
combine(input, k, 0, branch, 0);

Generating the power set:

Generating all subsets (the power set) of a given set of characters (or numbers) is very similar to generating combinations. While generating k-element combinations our goal was to print the current "branch" only when it holds all k characters from the given string. Since a power set contains combinations of all lengths, we simply call combine to generate the k-combinations for every k where 0 <= k <= SIZE(string).

void powerSet(char[] arr)
{
    for (int i = 0; i <= arr.length; ++i)
    {
        char[] branch = new char[i];
        combine(arr, i, 0, branch, 0);
    }
}

To call the powerSet method, simply pass in the character array we want to construct the power set of.

I hope you now have a good idea of how to generate all permutations, k-element combinations, and the power set of a given set of elements. I used to find these really hard in the beginning. But once I have started thinking these in terms of graph traversal problems, things became much easier!

Graph algorithms rock :)

Monday, February 20, 2012

Life on a Tree - Creating, Copying, and Pruning Tree Data Structures

The tree is one of the most important data structures in computer programming.



A very common example of the tree data structure in everyday life is the directory structure used in both Windows and Unix systems.

Today I will write a brief tutorial on trees, and show a number of common operations done on the structure of a tree. The code examples will be in Java.

First, let's create a tree. A tree is essentially a set of node objects where each node holds some data and has zero or more "child" nodes. The nodes without any child nodes are called leaves, and the node that is not itself a child of any other node is called the root. You can read up on trees on the Web; we will focus on the implementation of trees here.

We can define a node in the tree as follows:

class Node
{
    private int data; // this can be a more complex object
  
    private Node left;
    private Node right;
    private Node parent;
  
    public Node(int value)
    {
        data = value; // we could also use setter/getter for this value
        parent = null; // this is optional, but having this is often useful
        left = null;
        right = null;
    }
   
    public Node(Node node) // copy constructor
    {
        data = node.getData();
        left = null;
        right = null;
        parent = null;
    }

    public void setData(int data)
    {
        this.data = data;
    }
   
    public int getData()
    {
        return data;
    }
   
    public void setLeftNode(Node node)
    {
       left = node;
    }

    public Node getLeftNode()
    {
        return left;
    }

    public void setRightNode(Node node)
    {
        right = node;
    }

    public Node getRightNode()
    {
       return right;
    }

    public void setParentNode(Node node)
    {
         parent = node;
    }

    public Node getParentNode()
    {
        return parent;
    }
}

The class above is a simple example of a node in a tree. It contains a data variable, references to the left and right sub-trees, and a reference to the parent node as well. There is also a set of setters and getters to manipulate the data members. Please note that in practice you might need much more complex nodes, loaded with big data objects and more children than just left and right nodes.

To create a tree using the class above you could do the following:

Node root = new Node(100); // creating a root node with value 100

root.setLeftNode(new Node(200));
root.setRightNode(new Node(50));

root.getLeftNode().setLeftNode(new Node(10));
root.getLeftNode().setRightNode(new Node(1000));

A tree is represented by the root node of the tree. So, the "root" object above represents the entire tree.

Copying a tree:

Say you have a tree that you want to copy to another tree. Copying can be done in two ways: i) shallow copy, and ii) deep copy.

We will show deep copy here. Deep copying means the new tree is an entirely new copy of the old tree, both in structure and in data. Everything is allocated again in memory.

The following function does a deep copy of a tree to another tree:

public void deepCopy(Node root, Node copy)
{
        Node left = root.getLeftNode();
        Node right = root.getRightNode();
       
        if (left != null)
        {
            copy.setLeftNode(new Node(left));
            deepCopy(left, copy.getLeftNode());
        }
       
        if (right != null)
        {
            copy.setRightNode(new Node(right));
            deepCopy(right, copy.getRightNode());
        }
}

This function can be called like this:

Node copy = new Node(root); // copy the root
deepCopy(root, copy); // copy the rest of the tree

Now, the tree referenced by "copy" holds an entire deep copy of the tree "root".

Pruning a tree:

Pruning means deleting one or more subtrees of a tree. We will implement filter-based pruning here. That is, whenever a node matches some criterion described in a filter, we will delete that node, along with all its children, from its parent.

First, we will need a way to represent a filter. We will do that by way of an interface that all filter classes will have to implement.

interface Filter
{
    public boolean check(Node node);
}

Now, we can define a Filter class as follows:

class MyFilter implements Filter
{
    public boolean check(Node node)
    {
        if (node.getData() == 200)
            return false;
        return true;
    }
}

This class indicates that we would like to delete all sub-trees rooted at a node containing the data value 200.

The pruning method will be a lot like the deep copy method. This is because the pruned tree is actually a deep copy of the original tree minus the pruned nodes!

Here is what the pruning method looks like:

public void pruneTree(Node root, Node copy, Filter filter)
{
        Node left = root.getLeftNode();
        if (left != null && filter.check(left))
        {
            copy.setLeftNode(new Node(left));
            pruneTree(left, copy.getLeftNode(), filter);
        }
       
        Node right = root.getRightNode();
        if (right != null && filter.check(right))
        {
            copy.setRightNode(new Node(right));
            pruneTree(right, copy.getRightNode(), filter);
        }
}

This method can be called as follows:

Node copy = new Node(root);
Filter filter = new MyFilter();
pruneTree(root, copy, filter);

At this point, the node object "copy" will contain a pruned version of the original tree rooted at "root".

Finally, here is a small method that prints a tree in pre-order traversal. I used this method in testing out the code in this post:

public void printTree(Node root)
{
        if (root != null)
        {
            System.out.println(root.getData()); 
            printTree(root.getLeftNode());
            printTree(root.getRightNode());
        }
}

I hope you will now be able to create simple trees quickly. It won't be hard to modify these classes and methods to build more complex trees that handle more than two children, copy a subtree instead of the whole tree, prune more selectively, and so on.

Happy tree coding:)

Friday, February 10, 2012

Solving the Boggle Game - Recursion, Prefix Tree, and Dynamic Programming

I spent this past weekend designing the game of Boggle. Although it looks like a simple game at a high level, implementing it in a programming language was a great experience. I had to use recursion, sorting, searching, prefix trees (also known as tries), and dynamic programming as I was improving the run time of the program. So here is a summary of the work I did in trying to write an optimal solution to the Boggle game:



*Image courtesy: http://stackoverflow.com/questions/746082/how-to-find-list-of-possible-words-from-a-letter-matrix-boggle-solver

The game of Boggle is played on an N x N board (usually made of cubes that have letters engraved on them). Given a dictionary, you will have to construct words on the board following these rules:

i) The letters in the word must be "adjacent" to each other
ii) Two letters on the board are "adjacent" if they are located on the board next to each other horizontally, vertically, or diagonally.
iii) A word must be 3 or more letters long to be valid

The more words you can construct, the more points you get. Longer words get more points.

In computerese: given an N x N board, and a dictionary containing hundreds of thousands of words, construct those words from the board that are found in the dictionary.

There are a number of approaches that one can take to solve this problem. I designed three different algorithms to solve it. Here they are:

First solution - Recursion + Binary Search:

In this approach, we recurse (using backtracking) through the board and generate all possible words. For each word that is three or more letters long, we check whether it's in the dictionary. If it is, we have a match!

Here are the algorithm's steps:

1. Read in the dictionary in to a container in memory.
2. Sort the dictionary.
3. Using backtracking, search the board for words.
4. If a word is found and it contains 3 or more letters, do a binary search on the dictionary to see if the word is there.
5. If searching was successful in the previous step, print the word out.
6. Continue steps 3-5 as long as there are more words to construct on the board.
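The dictionary side of this approach (steps 1, 2, and 4) can be sketched in Java like this. The class and method names are my own, not from the original program:

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;

class DictionaryLookup
{
    private final List<String> words = new ArrayList<>();

    // Steps 1-2: load the dictionary into memory and sort it once
    public DictionaryLookup(List<String> dictionary)
    {
        words.addAll(dictionary);
        Collections.sort(words);
    }

    // Step 4: binary search for a candidate word produced by the backtracking
    public boolean contains(String candidate)
    {
        return candidate.length() >= 3
            && Collections.binarySearch(words, candidate) >= 0;
    }
}
```

Collections.binarySearch returns a non-negative index only when the element is present, which is all we need here.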

Complexity of this approach:

In this solution, we do a good job on the dictionary search. Using binary search, we are quickly finding out whether a word is in dictionary or not. But the real bottleneck is in searching the board for words. For an N x N board the search space is O((N*N)!). I am not exactly sure about this number, but you can find some discussions of it here: http://www.danvk.org/wp/2007-08-02/how-many-boggle-boards-are-there/.
(N*N)! is a HUGE number even for N = 5. So this approach is impractical and out of the question for any useful implementation.


Second Solution - Pruned Recursion + Prefix Tree (also known as a Trie):

From the previous approach, our major concern was the enormous search space on the board. Fortunately, using a prefix tree or trie data structure we could significantly cut down on this search space. The reasoning behind this improvement is that if a word "abc" does not occur as a prefix of any word in the dictionary, there is no reason to keep searching for words after we encounter "abc" on the board. This actually cut down the run time a lot.

Here are the algorithm's steps:

1. Read a word from the dictionary file.
2. Insert it into a prefix tree data structure in memory.
3. Repeat steps 1-2 until all words in the dictionary have been inserted into the prefix tree.
4. Using backtracking, search the board for words.
5. If a word is found and it contains 3 or more letters, search the prefix tree for the word.
6. If searching was *not* successful in the previous step, return from this branch of the backtracking stage. (There is no point in continuing along this branch; the prefix tree says no dictionary word starts with this prefix.)
7. If searching was successful in step 5, continue searching by constructing more words along this branch of backtracking and stop when the leaf node has been reached in the prefix tree. (at that point there is nothing more to search).
8. Repeat steps 4-7 as long as there are more words to search in the backtracking.
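The prefix tree used in steps 5-7 can be sketched as follows. This is my own minimal version, not the exact code I used, and it assumes plain string keys:

```java
import java.util.HashMap;
import java.util.Map;

class Trie
{
    private static class TrieNode
    {
        final Map<Character, TrieNode> children = new HashMap<>();
        boolean isWord; // true if a dictionary word ends at this node
    }

    private final TrieNode root = new TrieNode();

    // Steps 1-3: insert every dictionary word into the trie
    public void insert(String word)
    {
        TrieNode cur = root;
        for (char c : word.toCharArray())
            cur = cur.children.computeIfAbsent(c, k -> new TrieNode());
        cur.isWord = true;
    }

    // Step 6: if this returns false, abandon this branch of the backtracking
    public boolean hasPrefix(String prefix)
    {
        return find(prefix) != null;
    }

    // Step 5: check whether a complete word is in the dictionary
    public boolean isWord(String word)
    {
        TrieNode node = find(word);
        return node != null && node.isWord;
    }

    // Walk down the trie one character at a time; null means no such path
    private TrieNode find(String s)
    {
        TrieNode cur = root;
        for (char c : s.toCharArray())
        {
            cur = cur.children.get(c);
            if (cur == null)
                return null;
        }
        return cur;
    }
}
```

The pruning power comes from hasPrefix: the backtracking calls it after every letter it adds, and stops extending the current path the moment it returns false.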

Complexity of this approach:

This approach significantly improves on the first one. Building a prefix tree out of the dictionary words is O(W * L), where W is the number of words in the dictionary and L is the maximum length of a word. Searching the board is now bounded by the words actually in the dictionary, since we never keep extending a prefix that no dictionary word starts with. In reality it is still more work than that, because we must backtrack along the board to construct candidate words before the prefix tree can tell us whether they exist or not.

Third and Final Solution - No search space + Dynamic Programming:

The 2nd approach mentioned above was good enough until the board size was 5. Unfortunately with a board size of 6, that too was taking forever to complete!



It got me into thinking - "Dang, this search space is still too big to search! Can I just get rid of it entirely?" And then this idea popped into my mind: instead of randomly constructing word after word in this infinite ocean of words, why don't I take a word from the dictionary and somehow magically check whether it's available on the board or not?

It turns out, we can use a nifty dynamic programming technique to quickly check whether a word (from the dictionary in this case) can be constructed from the board or not!

Here is the core point of the dynamic programming idea:


For a word of length k to be found (end location) at the [i, j]-th location of the board, the k-1'th letter of that word must be located in one of the adjacent cells of [i, j].



The base case is k = 1.

A word of length 1 will be found (end location) in the [i, j]-th cell of the board if the only letter in the word matches the letter in the [i, j]-th location of the board.

Once our dynamic programming table is populated with the base case, we can build on top of that for any word of length k, k > 1.

Here is a sample code for this:

for (k = 2; k <= word_length; ++k)
    for (i = 0; i < N; ++i)
        for (j = 0; j < N; ++j)
            if (board[i][j] == word[k])
                for all the "adjacent" cells of board[i][j]
                    if table[k-1][adjacent_i][adjacent_j] is true
                        then table[k][i][j] = true
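The pseudocode above translates into a runnable Java method roughly like this. This is a sketch of my own (0-indexed, unlike the 1-indexed pseudocode); like the DP table itself, it does not prevent a board cell from being reused within one word, which is why the full C++ program further down re-verifies each hit with a backtracking check:

```java
class BoggleDP
{
    // Returns true if word can be traced on the board through adjacent
    // cells, ignoring the cell-reuse issue (a cell may be counted twice)
    public static boolean canForm(char[][] board, String word)
    {
        int n = board.length;
        int len = word.length();
        int[] dx = {0, 1, 1, 1, 0, -1, -1, -1};
        int[] dy = {1, 1, 0, -1, -1, -1, 0, 1};
        boolean[][][] table = new boolean[len][n][n];

        // Base case: a 1-letter prefix ends wherever its letter appears
        for (int i = 0; i < n; ++i)
            for (int j = 0; j < n; ++j)
                table[0][i][j] = (board[i][j] == word.charAt(0));

        // A (k+1)-letter prefix can end at [i][j] only if the k-letter
        // prefix ends at one of the 8 adjacent cells
        for (int k = 1; k < len; ++k)
            for (int i = 0; i < n; ++i)
                for (int j = 0; j < n; ++j)
                {
                    if (board[i][j] != word.charAt(k))
                        continue;
                    for (int d = 0; d < 8 && !table[k][i][j]; ++d)
                    {
                        int ti = i + dx[d], tj = j + dy[d];
                        if (ti >= 0 && ti < n && tj >= 0 && tj < n
                                && table[k - 1][ti][tj])
                            table[k][i][j] = true;
                    }
                }

        // The word is formable if its last letter ends anywhere on the board
        for (int i = 0; i < n; ++i)
            for (int j = 0; j < n; ++j)
                if (table[len - 1][i][j])
                    return true;
        return false;
    }
}
```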


*I have the C++ code, if you are interested to take a look leave a comment with your email address.

Run Time Complexity:

The run time complexity of this approach is pretty obvious: O(W * N * N * MAX_WORD_LENGTH). Here W is the number of words in the dictionary and N is the dimension of the board, which is usually between 4 and 6. So essentially the algorithm runs in the order of the size of the dictionary!

I solved a 6 x 6 Boggle game with a 600,000-word dictionary in 0.002 seconds! With I/O it was about 0.57 seconds. The trie and binary search approaches took much longer!

So, here is the summary of my weekend's worth of adventure into the world of algorithms and data structures. Feel free to ask me any questions, or to point out any bugs in these algorithms :)

Credit where it is due: My friend Satej, who is a ninja problem solver and a PhD student at UCF, explained the dynamic programming solution to me first, before I refined and finalized it. Thanks Satej!


Update (5/10/2012):

So many people have asked for the source code that I feel I had better add the C++ code here:) The final, polished code for this is not on my home computer, so if there is a bug please let me know!


#include <cstdio>
#include <cstring> // for strlen, strcpy, memset
#include <iostream>

using namespace std;

const int N = 6; // dimension of the (N x N) board

char in[N * N + 1]; // current word (a word can span at most N*N cells)
char board[N+1][N+2]; // keep room for a newline and null char at the end
char prev[N * N + 1]; // previously processed word
bool dp[N * N + 1][N][N];

// direction X-Y delta pairs for adjacent cells
int dx[] = {0, 1, 1, 1, 0, -1, -1, -1};
int dy[] = {1, 1, 0, -1, -1, -1, 0, 1};
bool visited[N][N];

bool checkBoard(char* word, int curIndex, int r, int c, int wordLen)
{
    if (curIndex == wordLen - 1)
    {
        //cout << "Returned TRUE!!" << endl;
        return true;
    }
   
    bool ret = false;
       
    for (int i = 0; i < 8; ++i)
    {
        int newR = r + dx[i];
        int newC = c + dy[i];
       
        if (newR >= 0 && newR < N && newC >= 0 && newC < N && !visited[newR][newC] && word[curIndex+1] == board[newR][newC])
        {
            ++curIndex;
            visited[newR][newC] = true;
           
            ret = checkBoard(word, curIndex, newR, newC, wordLen);
            if (ret)
                break;
               
            --curIndex;
            visited[newR][newC] = false;
        }
    }
   
    return ret;           
}

int main(int argc, char* argv[])
{
   
    int i, j, k, l;
   
    FILE* fDict = fopen("dict.txt","r");
    FILE* fBoard = fopen("tmp2.txt","r");
   
    for(i = 0; i < N; ++i)
        fgets(board[i], N+2, fBoard);
   
    strcpy(prev,"");
    int pLen = 0;
   
    while(fgets(in, N*N + 1, fDict))
    {
        int len = strlen(in);
        if (in[len-1] == '\n')
        {
            in[len-1] = '\0'; // remove the trailing newline
            --len;
        }
       
        if (len < 3)
            continue; // we only want words of 3 or more letters
           
        for(i = 0; i < len && i < pLen; ++i)
        {
            if(prev[i] != in[i])
                break;
        }
       
        int firstMismatch = i; // little optimization: building on previous word (will benefit if the word list is sorted)
       
        if(firstMismatch==0)
        {
            for(i = 0; i < N; ++i)
            {
                for(j = 0; j < N; ++j)
                {
                    if(board[i][j] == in[0])
                        dp[0][i][j] = true;
                    else
                        dp[0][i][j] = false;
                }
            }
            firstMismatch = 1;
        }
       
        for(k = firstMismatch; k < len; ++k)
        {
            for(i = 0; i < N; ++i)
            {
                for(j = 0; j < N; ++j)
                {   
                    dp[k][i][j] = false;
                           
                    if(board[i][j] != in[k])
                        continue;
                       
                    for(l= 0; l < 8 && !dp[k][i][j]; ++l)
                    {
                        int ti = i + dx[l];
                        int tj = j + dy[l];
                       
                        if(ti < 0 || ti >= N || tj < 0 || tj >= N)
                            continue;
                       
                        if (dp[k-1][ti][tj])
                            dp[k][i][j] = true;
                    }
                }
            }  
        }
       
        // check if the word is tagged as found in the dp table
        bool flag = false;
        for(i = 0; i < N && !flag; ++i)
        {
            for(j = 0; j < N && !flag; ++j)
            {
                if(dp[len-1][i][j])
                    flag =true;
            }
        }
       
        // the dp table says it's there, but verify that the word can actually be traced on the board without repeating a cell
        if(flag)
        {
            //cout << "Checking word: " << in << endl;
            bool verified = false;
           
            for (i = 0; i < N && !verified; ++i)
            {
                for (j = 0; j < N && !verified; ++j)
                {
                    if (in[0] != board[i][j])
                        continue;
                       
                    memset(visited, false, sizeof(visited));
                    visited[i][j] = true;
                   
                    if (checkBoard(in, 0, i, j, len))
                    {
                        cout << in << endl;
                        verified = true; // stop searching; otherwise the word may be printed again from another starting cell
                        break;
                    }
                }
            }
        }
       
        strcpy(prev,in);
        pLen=len;
           
    }
   
    return 0;
}


Enjoy!

Update 3/15/2015: I have added a Java version of the boggle game solution on my Github page here: https://github.com/bilash/boggle-solver

The Java version builds a Trie for the dictionary and then uses the dynamic programming approach mentioned above.

Friday, January 20, 2012

The Power of Perl Regular Expressions

I am not a big scripting language person. C++ and Java (and occasional C#) have been my stronghold so far. But recently I have been strolling around in the Perl world at work. And to my amazement I found Perl so rich a language that I fell in love with it.

Here is an example of the power and succinctness of Perl. This one involves using Perl's powerful regular expressions and file/directory handling features.

For a project I was doing I needed to modify the contents of a large number of files in a directory. The actual modification needed was really simple: for all occurrences of a term, say XYZ, replace it with .XYZ (insert a dot before the term). The term XYZ may or may not be preceded by a "-" (minus sign). Other than that, the term is guaranteed to be separated by white space before and after it. Also, the term can occur at the beginning or at the end of a line.

Now, I was new to Perl, and to regular expressions for that matter. I struggled with the problem for a day and a half. There were always one or two cases I was not covering, or cases I was covering that I was not supposed to. Finally I gave up and posted this as a question on Stackoverflow:

Overlapping text substitution with Perl regular expression.

Within 10 minutes I had a bunch of answers from some real Perl experts. I am so thankful to the people who answered and saved me from more days of struggling with the problem!

Here is what I ended with for the whole problem:


$^I = ".copy";
our @ARGV = <*.txt>; # note: "my @ARGV" would shadow the global @ARGV that the diamond operator reads

while (<>)
{
    s/((?:^|\s)-?)(XYZ)(?=\s|$)/$1.$2/g;
    print;
}


That's it! How cool is that?

If it were C++ or Java, I would have to deal with opening file streams (buffered streams in Java), opening one file at a time, reading lines one by one, storing them in a list or some other in-memory structure, and then writing the whole thing back to the file system, and so on. I would easily end up writing close to a hundred lines of code for that!

But in Perl, it's just those few lines above. The diamond operator (<>) takes care of the directory and file browsing so succinctly. And the one-line regular expression substitution finds all occurrences of the pattern in all the files (matched by the wildcard *.txt) and inserts a dot before each match.

I admit, Perl's regex and file/directory handling packages are doing the work in the background, but still, the expressive power of these tools is so elegant!

I hope to keep exploring Perl more in the coming days. Especially the powerful regular expressions!

Wednesday, August 17, 2011

A Binary Search Trees Tutorial

This is a short tutorial on Binary Search Trees covering the basic operations that are often done on such trees.

A Binary Search Tree (BST) is a tree data structure that has the following property:

For any node x in the tree, the left subtree rooted at x only contains nodes with values less than or equal to the value stored at x; and the right subtree rooted at x only contains nodes with values greater than the value at x. Also, the left subtree and the right subtree of x must be binary trees themselves.


                              Figure: A binary search tree of size 9 and depth 3, with root 8
                                         and leaves 1, 4, 7 and 13
                                        Image source: Wikipedia


BSTs have sub-linear (logarithmic) average case complexity for element insertion and searching. Once a BST is built, its values can be read off in sorted order in linear time. All we need is an inorder traversal.

There are a number of operations that can be done on a BST. We will describe those here in some details with Java code examples.

First let's see how we can represent a tree node in Java:

class TreeNode
{
    int data;
    TreeNode parent;
    TreeNode left;
    TreeNode right;
   
    public TreeNode(int data)
    {
        this.data = data;
        parent = left = right = null;
    }
   
    public void setData(int data)
    {
        this.data = data;
    }
   
    public void setLeft(TreeNode node)
    {
        left = node;
    }
   
    public void setRight(TreeNode node)
    {
        right = node;
    }
   
    public void setParent(TreeNode node)
    {
        parent = node;
    }
   
    public int getData()
    {
        return data;
    }
   
    public TreeNode getLeft()
    {
        return left;
    }
   
    public TreeNode getRight()
    {
        return right;
    }
   
    public TreeNode getParent()
    {
        return parent;
    }
}

Inserting a node into a BST:

Inserting a node into a BST is pretty straightforward. We start at the root node and walk down the tree: at each node, if its value is greater than the value to be inserted, we move to the left child; otherwise we move to the right child. We repeat this until there is no node left to traverse. The last node visited will be the parent of the new node. If the new node's value is smaller than this parent's value we attach it as the left child; otherwise we attach it as the right child.

Java code:

public static void insert(TreeNode node)
{
    TreeNode x, y;
    // root is a global variable holding the root of the tree (null if the tree is empty)
    x = y = root;
   
    while (x != null)
    {
        if (x.getData() > node.getData())
        {
            y = x;
            x = x.getLeft();
        }
        else
        {
            y = x;
            x = x.getRight();
        }
    }
   
    // y will be the parent of node
    if (y == null)
    {
        root = node;
        return;
    }
   
    if (y.getData() > node.getData())
        y.setLeft(node);
    else
        y.setRight(node);
   
    node.setParent(y);
}

Deleting a node from a BST:

When deleting a node from a BST, there can be 3 different scenarios:
i) The node does not have any children: This is kind of a trivial case. When there are no children, all we need to do is remove the node from its parent.
ii) The node has only one child: This is also a relatively simple case. Since the node has only one child, we simply replace the node with this lone child without violating the BST sorted-order property.
iii) The node has two children: Now, this is a bit tricky :) Since there are two children, we have to be careful not to break the sorted order of the BST. To keep the order intact, we replace the node to be deleted with its successor so that the order is maintained in the node's absence. It turns out the successor can have at most one child (it can never have a left child), so if the successor has a child we replace the successor with that child after the successor takes the place of the deleted node.

Java code:

public static void delete(TreeNode node)
{
    // Case 1: node does not have a child, just detach it from its parent
    if (node.getLeft() == null && node.getRight() == null)
    {
        if (node.getParent() != null && node.getParent().getLeft() == node)
            node.getParent().setLeft(null);
        else if (node.getParent() != null && node.getParent().getRight() == node)
            node.getParent().setRight(null);
    }
    // Case 2: node has only one child, splice the child with its parent
    else if (node.getLeft() == null || node.getRight() == null)
    {
        if (node.getParent() != null)
        {
            TreeNode x = node.getLeft() == null ? node.getRight() : node.getLeft();
            if (node.getParent().getLeft() == node)
                node.getParent().setLeft(x);
            else
                node.getParent().setRight(x);
        }           
    }
    // Case 3: node has both children, replace its data with its successor's
    else
    {
        TreeNode x = findSuccessor(node); // x will have at most one child
        // Instead of deleting we can just copy the successor's data over to the node to be deleted
        node.setData(x.getData());
        // Now splice the successor out: since x is the minimum of node's right
        // subtree, it can never have a left child, so its only possible child
        // is its right child
        TreeNode nodeChild = x.getRight();
        if (x.getParent().getLeft() == x)
            x.getParent().setLeft(nodeChild);
        else
            x.getParent().setRight(nodeChild);
        if (nodeChild != null)
            nodeChild.setParent(x.getParent());
    }
}

Finding the minimum value in the tree:

In a BST, all nodes to the left of a node contain smaller (or equal) values than the value stored in the node itself. This gives us a clue about how to tackle the problem of finding the minimum value stored in the BST. If you think about it, the node with the minimum value in a BST is actually the leftmost node in the tree. If it weren't, then there would be node(s) in the tree that are on the right of some node but contain smaller values than that node itself, thus violating the BST contract.

Here is the Java code for finding the minimum value node in a BST:

public TreeNode findMinimum(TreeNode root)
{
    if (root == null)
        return null;
  
    if (root.getLeft() != null)
        return findMinimum(root.getLeft());
  
    return root;
}

Finding the maximum value in the tree:

Similar to the logic for finding the minimum value, the maximum value node will be the rightmost node in the tree. The Java code for finding the maximum value node in a BST:

public static TreeNode findMaximum(TreeNode root)
{
    if (root == null)
        return null;
  
    if (root.getRight() != null)
        return findMaximum(root.getRight());
  
    return root;
}

One useful property of the BST is that if it is traversed in inorder we get a sorted list of values stored in the tree nodes. There are two operations that are often done in relation to maintaining this sorted order of a BST: finding the successor node and predecessor node of a given node.
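For completeness, here is what such an inorder traversal looks like. This is a small sketch of my own that collects the values into a list; for brevity it bundles a stripped-down node class, but with the full TreeNode class above you would call getData(), getLeft(), and getRight() the same way:

```java
import java.util.ArrayList;
import java.util.List;

// Abbreviated stand-in for the TreeNode class defined earlier in the post
class BstNode
{
    int data;
    BstNode left, right;
    BstNode(int data) { this.data = data; }
}

class Inorder
{
    // Visit the left subtree, then the node, then the right subtree:
    // for a BST this yields the values in non-decreasing order
    static void traverse(BstNode node, List<Integer> out)
    {
        if (node == null)
            return;
        traverse(node.left, out);
        out.add(node.data);
        traverse(node.right, out);
    }
}
```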

Finding the successor node of a given node:

The successor node is the node with the next bigger value in the tree. Since all values to the left are smaller than the current node's, the successor has to be located in the right subtree. And since it holds the next bigger value, it has to be the smallest of all values in the right subtree. So, essentially, we are looking for the minimum value in the right subtree of the current node. If the node has no right child, we need to look up the tree instead: the successor is the first ancestor whose left child is also an ancestor of this node. The intuition is that as we traverse up and to the left we pass over smaller values, so the first ancestor reached from its left side holds the next larger value.

Here is the Java code:

public static TreeNode findSuccessor(TreeNode node)
{
    if (node == null)
        return null;
   
    if (node.getRight() != null)
        return findMinimum(node.getRight());
  
    TreeNode y = node.getParent();
    TreeNode x = node;
    while (y != null && x == y.getRight())
    {
        x = y;
        y = y.getParent();
    }

    return y;
}

Finding the predecessor node of a given node:

The predecessor of a given node is the node containing the next smaller value. Finding the predecessor follows symmetric rules that we used for finding the successor.

Java code for finding the predecessor:

public static TreeNode findPredecessor(TreeNode node)
{
    if (node == null)
        return null;
   
    if (node.getLeft() != null)
        return findMaximum(node.getLeft());
  
    TreeNode y = node.getParent();
    TreeNode x = node;
    while (y != null && x == y.getLeft())
    {
        x = y;
        y = y.getParent();
    }

    return y;
}

Determining whether a tree is a BST or not:

Sometimes we already have a binary tree that we need to determine whether it is a BST or not. This is an interesting problem and can really be solved with a simple recursive solution.

The BST property - that every node in the right subtree has to be larger than the current node and every node in the left subtree has to be smaller than (or equal to) the current node - is the key to figuring out whether a tree is a BST or not. At first thought it might look like we can simply traverse the tree and, at every node, check whether the node contains a value larger than the value at the left child and smaller than the value at the right child; if this condition holds for all the nodes in the tree, then we have a BST. This is the so-called greedy approach, making a decision based on local properties alone. But this approach clearly won't work for the following tree:

        20
       /  \
     10    30
          /  \
         5    40

In the tree above, every node does contain a value larger than its left child and smaller than its right child, yet it's not a BST: the value 5 is in the right subtree of the node containing 20, a violation of the BST property!

So how do we solve this? It turns out that instead of making a decision based solely on a node and its children's values, we also need information flowing down from the parent as well. In the case of the tree above, if we could remember about the node containing the value 20 we could see that the node with value 5 is violating the BST property contract.

So the condition we need to check at each node is: a) if the node is the left child of its parent, then it must be smaller than (or equal to) the parent, and it must pass down the parent's value to its right subtree to make sure none of the nodes in that subtree is greater than the parent; and symmetrically, b) if the node is the right child of its parent, then it must be larger than the parent, and it must pass down the parent's value to its left subtree to make sure none of the nodes in that subtree is smaller than the parent.

A simple but elegant recursive solution in Java can explain this further:

public static boolean isBST(TreeNode node, int leftData, int rightData)
{
    if (node == null)
        return true;
   
    if (node.getData() > leftData || node.getData() <= rightData)
        return false;
   
    return (isBST(node.left, node.getData(), rightData) && isBST(node.right, leftData, node.getData()));
}

The initial call to this function can be something like this:

if (isBST(root, Integer.MAX_VALUE, Integer.MIN_VALUE))
    System.out.println("This is a BST.");
else
    System.out.println("This is NOT a BST!");

Essentially we keep maintaining a valid range (starting from [MIN_VALUE, MAX_VALUE]) and keep shrinking it down for each node as we go down recursively.

So, here is my short and sweet primer on Binary Search Trees. Hope you found it useful. Let me know if there is a bug or a possible improvement in runtime or space.