Flattening a Binary Tree

In which we design and implement an algorithm that, given a binary tree, flattens it in place.

For more info, see LeetCode.

Understanding the problem

The problem asks us to design an algorithm that takes a binary tree as input and flattens it so that every node except the last has only a right child, and the last node has no children at all.

Also, while not explicitly stated in the problem description, the nodes of the flattened tree must appear in a specific order, which is illustrated by an example. In particular, given the following tree:

    1
   / \
  2   5
 / \   \
3   4   6

The flattened tree should look like:

1
 \
  2
   \
    3
     \
      4
       \
        5
         \
          6

Finally, our algorithm should flatten the input tree in place: i.e. it should not rely on auxiliary data structures or create new tree nodes.
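To make the target shape concrete, here is a minimal sketch of a checker that walks a tree along its right pointers and confirms that no node has a left child. It assumes a LeetCode-style `TreeNode` with `val`, `left`, and `right` fields (LeetCode supplies such a class; the one below is our own stand-in), and the helper name `flattenedValues` is hypothetical.

```java
import java.util.ArrayList;
import java.util.List;

// Stand-in for LeetCode's TreeNode (assumption: int val plus left/right links).
class TreeNode {
  int val;
  TreeNode left;
  TreeNode right;

  TreeNode(int val) {
    this.val = val;
  }

  TreeNode(int val, TreeNode left, TreeNode right) {
    this.val = val;
    this.left = left;
    this.right = right;
  }
}

// Hypothetical helper for checking results; not part of the original article.
final class FlattenCheck {
  private FlattenCheck() {}

  // Returns the values along the right-pointer chain if the tree is flattened
  // (no node has a left child), or null as soon as a left child is found.
  static List<Integer> flattenedValues(TreeNode root) {
    List<Integer> values = new ArrayList<>();
    for (TreeNode node = root; node != null; node = node.right) {
      if (node.left != null) {
        return null;
      }
      values.add(node.val);
    }
    return values;
  }
}
```

For the six-node example above, a correct flatten should leave `flattenedValues(root)` equal to `[1, 2, 3, 4, 5, 6]`. The checker itself uses a list, which is fine because it only inspects the result.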

Observation

Looking closely at the resulting tree, we notice that its nodes appear in the same order in which a pre-order traversal of the original tree visits them:

TreeTraversals.newPreOrder(false /* visit null nodes? */)
              .traverse(root, n -> System.out.printf("%s ", n.val))
// Outputs: 1 2 3 4 5 6
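The observation is easy to check in isolation. The sketch below builds the example tree by hand and records a recursive pre-order traversal (node, then left subtree, then right subtree). It uses its own minimal `TreeNode` stand-in and a hypothetical `PreOrderDemo` class rather than the `TreeTraversals` class from Listing 1.

```java
import java.util.ArrayList;
import java.util.List;

// Minimal stand-in for LeetCode's TreeNode (assumption).
class TreeNode {
  int val;
  TreeNode left;
  TreeNode right;

  TreeNode(int val, TreeNode left, TreeNode right) {
    this.val = val;
    this.left = left;
    this.right = right;
  }
}

// Hypothetical demo class; not part of the original article.
final class PreOrderDemo {
  private PreOrderDemo() {}

  // Builds the example tree: 1 has children 2 and 5, 2 has children 3 and 4,
  // and 5 has only a right child, 6.
  static TreeNode exampleTree() {
    TreeNode two = new TreeNode(2, new TreeNode(3, null, null), new TreeNode(4, null, null));
    TreeNode five = new TreeNode(5, null, new TreeNode(6, null, null));
    return new TreeNode(1, two, five);
  }

  // Pre-order: record the node, then its left subtree, then its right subtree.
  static void preOrder(TreeNode node, List<Integer> out) {
    if (node == null) {
      return;
    }
    out.add(node.val);
    preOrder(node.left, out);
    preOrder(node.right, out);
  }

  static List<Integer> preOrderValues(TreeNode root) {
    List<Integer> out = new ArrayList<>();
    preOrder(root, out);
    return out;
  }
}
```

Calling `PreOrderDemo.preOrderValues(PreOrderDemo.exampleTree())` yields `[1, 2, 3, 4, 5, 6]`, the same order as the right-pointer chain of the flattened tree.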

This observation suggests an algorithm:

0- If the root node of our tree is null, do nothing and return
1- Create an empty list
2- Traverse the tree in pre-order
  2.1- Each time a node is visited, add it to the list
3- Loop over the list
  3.1- Set the right-node of the current node to the next node
     (if the current node is the last node, set its right node to null)
  3.2- Set the left-node of the current node to null

Before implementing the algorithm, let’s write a utility class named TreeTraversals with methods for traversing a tree in pre-order, in-order, and post-order. Since we are only interested in pre-order traversal here, we will leave the other two out.

Utility classes and methods

Here’s our TreeTraversals class:

Listing 1. TreeTraversals
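import java.util.function.Consumer;

// TreeNode is the binary tree node class supplied by LeetCode
// (fields: int val, TreeNode left, TreeNode right).
public class TreeTraversals {

  // Common contract for all traversal orders.
  public interface TreeTraversal {
    void traverse(TreeNode node, Consumer<TreeNode> visitor);
  }

  public static PreOrder newPreOrder(final boolean visitNulls) {
    return new PreOrder(visitNulls);
  }

  public static abstract class AbstractTraversal implements TreeTraversal {

    // When true, the visitor is also called with null for each missing child.
    protected final boolean visitNulls;

    public AbstractTraversal(final boolean visitNulls) {
      this.visitNulls = visitNulls;
    }
  }

  // Visits a node first, then its left subtree, then its right subtree.
  public static class PreOrder extends AbstractTraversal {
    public PreOrder(final boolean visitNulls) {
      super(visitNulls);
    }

    @Override
    public void traverse(TreeNode node, Consumer<TreeNode> visitor) {
      if (node == null) {
        if (visitNulls) {
          visitor.accept(null);
        }

        return;
      }

      visitor.accept(node);

      traverse(node.left, visitor);
      traverse(node.right, visitor);
    }
  }
}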

First solution

Now we have everything that we need in order (no pun intended) to implement our algorithm in Java:

Listing 2. Flatten (not in-place)
import java.util.ArrayList;
import java.util.List;

public void flatten(TreeNode root) {
  // Step 0
  if (root == null) {
    return;
  }

  // Step 1
  List<TreeNode> nodes = new ArrayList<>();

  // Step 2 and 2.1
  TreeTraversals.newPreOrder(false).traverse(root, nodes::add);

  // Step 3
  for (int i = 0; i < nodes.size(); i++) {
    TreeNode currentNode = nodes.get(i);

    // Step 3.1
    currentNode.right = i == nodes.size() - 1 ? null : nodes.get(i + 1);

    // Step 3.2
    currentNode.left = null;
  }
}

However, there’s a big problem with this solution: it does not flatten the tree in place, because it relies on auxiliary memory (the nodes list). In the next section, we devise an in-place implementation.

Second solution

If we spend some time thinking about a solution that satisfies all the requirements of the problem, fail, take a shower, think and fail again, sleep, think and fail once more… we will eventually find one.

Given this tree:

     1
   /   \
  2     5
 / \   / \
3   4 6   7

it turns out that if we:

  1. Flatten the left subtree:

           1
         /   \
        2     5
         \   / \
          3 6   7
           \
            4
  2. Flatten the right subtree:

           1
         /   \
        2     5
         \     \
          3     6
           \     \
            4     7
  3. Set the right node of the only leaf node of the flattened left subtree (i.e. node 4) to the root node of the flattened right subtree (i.e. node 5):

         1
       /
      2
       \
        3
         \
          4
           \
            5
             \
              6
               \
                7
  4. Set the right node of the root node to the root of the flattened left subtree (i.e. node 2), and its left node to null:

         1
          \
           2
            \
             3
              \
               4
                \
                 5
                  \
                   6
                    \
                     7
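Steps 3 and 4 can be sketched in isolation. Assuming both subtrees of `root` are already flattened into right-pointing chains (the state after step 2), the hypothetical helper `spliceFlattenedChildren` below finds the leaf of the left chain by walking its right pointers, links that leaf to the right chain, and then moves the left chain to the right. `TreeNode` is again our own minimal stand-in.

```java
// Minimal stand-in for LeetCode's TreeNode (assumption).
class TreeNode {
  int val;
  TreeNode left;
  TreeNode right;

  TreeNode(int val, TreeNode left, TreeNode right) {
    this.val = val;
    this.left = left;
    this.right = right;
  }
}

// Hypothetical helper illustrating steps 3 and 4; not the article's code.
final class Splice {
  private Splice() {}

  // Precondition (assumed, not checked): root.left and root.right are already
  // flattened, i.e. chains that only use right pointers.
  static void spliceFlattenedChildren(TreeNode root) {
    if (root == null || root.left == null) {
      return; // nothing to move; the right chain is already in place
    }
    // The leaf of a flattened chain is its last node along the right pointers.
    TreeNode leftLeaf = root.left;
    while (leftLeaf.right != null) {
      leftLeaf = leftLeaf.right;
    }
    // Step 3: append the flattened right subtree after the left chain's leaf.
    leftLeaf.right = root.right;
    // Step 4: the left chain becomes the root's right child.
    root.right = root.left;
    root.left = null;
  }
}
```

Walking to the leaf takes time proportional to the length of the left chain, and repeating that walk at every node can add up on deep, left-heavy trees. Listing 3 avoids the walk by having each recursive call return the leaf it produced.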

This looks like a working solution. Let’s turn it into a recursive algorithm:

Flatten(node):
  1. Flatten(node.left) (let's call the resulting tree TL)
  2. Flatten(node.right) (let's call the resulting tree TR)
  3. Set Leaf(TL).right = Root(TR)
  4. Set Leaf(TL).left = null
  5. Set node.left = null
  6. Set node.right = Root(TL)
  7. Return Leaf(TR), the leaf of the whole flattened tree, so that the caller can perform its own step 3

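Now let’s implement it in Java. Because each call needs the leaf of the subtree it flattened, our flatten returns that leaf instead of void, and a few edge cases handle empty subtrees: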

Listing 3. Flatten (in-place)
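// Flattens the tree rooted at root in place and returns its leaf (last node).
public TreeNode flatten(TreeNode root) {
  // edge case 1: empty tree, so there is no leaf to return
  if (root == null) {
    return null;
  }

  // edge case 2: a single node is already flat and is its own leaf
  if (root.left == null && root.right == null) {
    return root;
  }

  // edge case 3: no left subtree; flattening the right subtree in place is
  // enough, and its leaf is the leaf of the whole tree
  if (root.left == null) {
    return flatten(root.right);
  }

  // edge case 4: no right subtree; flatten the left subtree and move it right
  if (root.right == null) {
    TreeNode leftLeaf = flatten(root.left);
    root.right = root.left;
    root.left = null;
    return leftLeaf;
  }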

  TreeNode leftNode = root.left;
  TreeNode rightNode = root.right;

  // Step 1
  TreeNode leftLeaf = flatten(root.left);

  // Step 2
  TreeNode rightLeaf = flatten(root.right);

  // Step 3
  leftLeaf.right = rightNode;

  // Step 4
  leftLeaf.left = null;

  // Step 5
  root.left = null;

  // Step 6
  root.right = leftNode;

  // Step 7
  return rightLeaf;
}

Update 1:

A cleaner Python implementation that avoids explicit handling of edge cases is presented by /u/sltkr on Reddit. Translated to Java, it would look like this:

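public void flatten(TreeNode root) {
  flatten(root, null);
}

// Flattens the subtree rooted at node, links its last node to next,
// and returns the first node of the resulting chain (next if node is null).
private TreeNode flatten(TreeNode node, TreeNode next) {
  if (node == null) {
    return next;
  }

  // Flatten the right subtree first so that its head follows the left subtree.
  node.right = flatten(node.left, flatten(node.right, next));
  node.left = null;
  return node;
}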
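The translation can be exercised end to end with a small harness. The sketch below wraps the two methods in a hypothetical `NextThreadingDemo` class (made static so no instance is needed), adds our own `TreeNode` stand-in, and runs it on the seven-node tree from the second solution; afterwards the right pointers should form the chain 1, 2, 3, 4, 5, 6, 7 with no left children.

```java
// Minimal stand-in for LeetCode's TreeNode (assumption).
class TreeNode {
  int val;
  TreeNode left;
  TreeNode right;

  TreeNode(int val, TreeNode left, TreeNode right) {
    this.val = val;
    this.left = left;
    this.right = right;
  }
}

// Hypothetical harness around the Java translation above.
final class NextThreadingDemo {
  private NextThreadingDemo() {}

  static void flatten(TreeNode root) {
    flatten(root, null);
  }

  // Flattens the subtree rooted at node, links its last node to next,
  // and returns the head of the chain (next itself if node is null).
  private static TreeNode flatten(TreeNode node, TreeNode next) {
    if (node == null) {
      return next;
    }
    node.right = flatten(node.left, flatten(node.right, next));
    node.left = null;
    return node;
  }

  // Builds the seven-node tree: 1 -> (2 -> (3, 4), 5 -> (6, 7)).
  static TreeNode exampleTree() {
    TreeNode two = new TreeNode(2, new TreeNode(3, null, null), new TreeNode(4, null, null));
    TreeNode five = new TreeNode(5, new TreeNode(6, null, null), new TreeNode(7, null, null));
    return new TreeNode(1, two, five);
  }
}
```

Because the empty-subtree case simply returns `next`, a missing left or right child needs no special branch, which is why this version drops the edge cases of Listing 3.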