How to flatten a binary tree to a linked list: recursion approach

Wangyy
Mar 14, 2021

Data structures such as binary trees and linked lists come up frequently in technical interviews. Sometimes interviewees are asked to transform one data structure into another, for example, converting a list into a linked list or an ordered list into a binary tree. Today, I’ll talk about a medium-difficulty problem on leetcode.com: flatten a binary tree to a linked list.

Description

Given the root of a binary tree, flatten the tree into a "linked list":

The “linked list” should use the same TreeNode class where the right child pointer points to the next node in the list and the left child pointer is always null.

The “linked list” should be in the same order as a pre-order traversal of the binary tree.
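To make “pre-order” concrete, here is a minimal recursive traversal that lists node values in root, left, right order. The flattened list has to visit the nodes in exactly this order. The helper name preorder is hypothetical, and the hand-built tree is the one encoded by the Input array in the example below.

```python
from typing import List, Optional


class TreeNode:
    # Same shape as the TreeNode class used in the Code section.
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def preorder(root: Optional[TreeNode]) -> List[int]:
    """Return node values in pre-order: root, then left subtree, then right subtree."""
    if root is None:
        return []
    return [root.val] + preorder(root.left) + preorder(root.right)


# The tree [1,2,5,3,4,null,6], built by hand.
root = TreeNode(1,
                TreeNode(2, TreeNode(3), TreeNode(4)),
                TreeNode(5, None, TreeNode(6)))
print(preorder(root))  # [1, 2, 3, 4, 5, 6]
```

The printed order matches the node order in the expected Output.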

Example:

Input: root = [1,2,5,3,4,null,6]
Output: [1,null,2,null,3,null,4,null,5,null,6]
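The Input and Output arrays use LeetCode’s level-order notation, where null marks a missing child and trailing nulls are dropped. The sketch below assumes that convention, with Python’s None standing in for null. build_tree and serialize are hypothetical helpers for experimenting locally; they are not part of the problem statement.

```python
from collections import deque
from typing import List, Optional


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def build_tree(values: List[Optional[int]]) -> Optional[TreeNode]:
    """Build a tree from a level-order list where None marks a missing child."""
    if not values or values[0] is None:
        return None
    root = TreeNode(values[0])
    queue = deque([root])
    i = 1
    while queue and i < len(values):
        node = queue.popleft()
        # The next value is the left child, the one after it the right child.
        if i < len(values) and values[i] is not None:
            node.left = TreeNode(values[i])
            queue.append(node.left)
        i += 1
        if i < len(values) and values[i] is not None:
            node.right = TreeNode(values[i])
            queue.append(node.right)
        i += 1
    return root


def serialize(root: Optional[TreeNode]) -> List[Optional[int]]:
    """Inverse of build_tree: level-order values with trailing Nones stripped."""
    out: List[Optional[int]] = []
    queue = deque([root])
    while queue:
        node = queue.popleft()
        if node is None:
            out.append(None)
            continue
        out.append(node.val)
        queue.append(node.left)
        queue.append(node.right)
    while out and out[-1] is None:
        out.pop()
    return out


# A chain 1 -> 2 -> ... -> 6 linked only through right pointers.
chain = None
for v in range(6, 0, -1):
    chain = TreeNode(v, None, chain)
print(serialize(chain))  # [1, None, 2, None, 3, None, 4, None, 5, None, 6]
```

A right-only chain serializes to the alternating value/None pattern of the expected Output, because every left child is missing.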

Solution

This problem asks us to convert a binary tree into a linked list that follows the pre-order traversal. Pre-order traverses a binary tree in root -> left -> right order. To convert the binary tree into this linear structure, let’s break the problem down into several sub-problems:

  • Flatten the left subtree in pre-order; call the result flattened_left_subtree;
  • Flatten the right subtree in pre-order; call the result flattened_right_subtree;
  • Link the root, flattened_left_subtree, and flattened_right_subtree together in that (pre-order) order.
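The three sub-problems can also be coded almost literally. The sketch below is a hypothetical variant, not the code used later in this post: flatten_parts flattens a subtree in place and also returns the tail of the resulting list, so the linking step can connect flattened_left_subtree to flattened_right_subtree without searching for the tail.

```python
from typing import Optional


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def flatten_parts(root: Optional[TreeNode]) -> Optional[TreeNode]:
    """Flatten the tree rooted at root in place and return the list's tail."""
    if root is None:
        return None
    # Sub-problems 1 and 2: flatten each subtree. A subtree's root stays the
    # head of its flattened list, and the recursive call returns its tail.
    left_head, right_head = root.left, root.right
    left_tail = flatten_parts(left_head)
    right_tail = flatten_parts(right_head)

    # Sub-problem 3: root -> flattened_left_subtree -> flattened_right_subtree.
    if left_head is not None:
        root.left = None
        root.right = left_head
        left_tail.right = right_head
    # The tail of the whole list is the tail of the last non-empty part.
    return right_tail or left_tail or root


root = TreeNode(1, TreeNode(2, TreeNode(3), TreeNode(4)), TreeNode(5, None, TreeNode(6)))
tail = flatten_parts(root)
values = []
node = root
while node:
    values.append(node.val)
    node = node.right
print(values, tail.val)  # [1, 2, 3, 4, 5, 6] 6
```

Returning the tail costs nothing extra, since each recursive call already knows it.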

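Let’s walk through the example tree from the description to see how it happens:

  • Assume we are given the binary tree [1,2,5,3,4,null,6]: node 1 has left child 2 (whose children are 3 and 4) and right child 5 (whose right child is 6).
  • Flatten the left and right subtrees: the left subtree becomes 2 -> 3 -> 4 and the right subtree becomes 5 -> 6.
  • Combine them at the end: 1 -> 2 -> 3 -> 4 -> 5 -> 6.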
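The following sketch prints each stage of this walk-through for the example tree. To keep the stages separate, it flattens each subtree with a simple hypothetical helper, flatten_by_list, which collects the nodes in pre-order and relinks them (O(n) extra space). The final splice is the linking step from the third sub-problem.

```python
class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def flatten_by_list(root):
    """Hypothetical helper: collect nodes in pre-order, then relink them in place."""
    nodes = []

    def visit(node):
        if node:
            nodes.append(node)
            visit(node.left)
            visit(node.right)

    visit(root)
    # Pair each node with its pre-order successor (None for the last node).
    for cur, nxt in zip(nodes, nodes[1:] + [None]):
        cur.left, cur.right = None, nxt


def right_chain(node):
    """Values reached by following right pointers from node."""
    vals = []
    while node:
        vals.append(node.val)
        node = node.right
    return vals


# The example tree [1,2,5,3,4,null,6].
root = TreeNode(1, TreeNode(2, TreeNode(3), TreeNode(4)), TreeNode(5, None, TreeNode(6)))

# Flatten the left and right subtrees separately.
flatten_by_list(root.left)
flatten_by_list(root.right)
print(right_chain(root.left))   # [2, 3, 4]
print(right_chain(root.right))  # [5, 6]

# Combine: root -> flattened_left_subtree -> flattened_right_subtree.
left_head, right_head = root.left, root.right
root.left, root.right = None, left_head
tail = left_head
while tail.right:
    tail = tail.right
tail.right = right_head
print(right_chain(root))  # [1, 2, 3, 4, 5, 6]
```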

Code

The function we’re going to implement is def flatten(self, root: TreeNode) -> None. Let’s implement the three parts above one at a time:

  • Step 1: flatten the left subtree. Since flatten itself flattens a given tree into a linked list, we can simply pass in the root of the left subtree: flatten(root.left) flattens it in place. The function returns None, but afterwards root.left is the head of flattened_left_subtree.
  • Step 2: flatten the right subtree. Similarly, flatten(root.right) flattens the right subtree in place, leaving root.right as the head of flattened_right_subtree.
  • Step 3: link the three parts together (only needed when the left subtree is non-empty): 1) save root.right, the head of flattened_right_subtree; 2) let root.right point to the head of flattened_left_subtree and set root.left to None; 3) find the tail of flattened_left_subtree and let tail.right point to the saved head of flattened_right_subtree.
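class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


# LeetCode-style wrapper class: flatten is a method, hence self.flatten below.
class Solution:
    def flatten(self, root: TreeNode) -> None:
        """
        Do not return anything, modify root in-place instead.
        """
        if not root:
            return None
        # Steps 1 and 2: flatten the left and right subtrees in place.
        self.flatten(root.left)
        self.flatten(root.right)

        # Step 3: only needed when there is a left subtree to splice in.
        if root.left:
            right = root.right      # save the head of flattened_right_subtree
            root.right = root.left  # root -> head of flattened_left_subtree
            root.left = None
            last = root
            while last.right:       # walk to the tail of flattened_left_subtree
                last = last.right
            last.right = right      # tail -> head of flattened_right_subtree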
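One cost worth noting: the while loop in Step 3 walks the whole flattened left subtree at every node, so on a tree where every node has only a left child the total work grows quadratically. A commonly used alternative, sketched below and not the approach of this post, visits nodes in right -> left -> root order (the reverse of pre-order) and points each node’s right pointer at the previously visited node, which takes O(n) time. The function name flatten_reverse is hypothetical.

```python
from typing import Optional


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def flatten_reverse(root: Optional[TreeNode]) -> None:
    """Flatten in place, visiting nodes in right -> left -> root order.

    That order is exactly reverse pre-order, so when a node is visited,
    prev is its pre-order successor, already linked to everything after it.
    """
    prev: Optional[TreeNode] = None

    def visit(node: Optional[TreeNode]) -> None:
        nonlocal prev
        if node is None:
            return
        visit(node.right)
        visit(node.left)
        node.right = prev  # link to the pre-order successor
        node.left = None
        prev = node

    visit(root)


# Each node has only a left child (1 -> 2 -> 3 via left pointers),
# the worst case for the tail walk in Step 3.
root = TreeNode(1, TreeNode(2, TreeNode(3)))
flatten_reverse(root)
vals = []
node = root
while node:
    vals.append(node.val)
    node = node.right
print(vals)  # [1, 2, 3]
```

Both versions recurse once per tree level, so very deep trees can still hit Python’s default recursion limit.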

Thanks for reading!
