How to flatten a binary tree to a linked list: a recursive approach
Data structures such as binary trees and linked lists come up frequently in technical interviews. Sometimes interviewees are asked to transform one data structure into another: for example, convert a list to a linked list, or convert a sorted list to a binary search tree. Today, I’ll talk about a Medium-difficulty problem on leetcode.com: flatten a binary tree to a linked list.
Description
Given the root of a binary tree, flatten the tree into a “linked list”:
- The “linked list” should use the same TreeNode class, where the right child pointer points to the next node in the list and the left child pointer is always null.
- The “linked list” should be in the same order as a pre-order traversal of the binary tree.
Example:
Input: root = [1,2,5,3,4,null,6]
Output: [1,null,2,null,3,null,4,null,5,null,6]
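Both arrays above use LeetCode’s level-order notation, where null marks a missing child. The sketch below shows one way to convert between that notation and TreeNode objects; build_tree and to_list are hypothetical helper names for illustration, not part of the problem’s official API. Note that the expected output is a chain that uses only right pointers.

```python
from collections import deque
from typing import List, Optional


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def build_tree(values: List[Optional[int]]) -> Optional[TreeNode]:
    """Hypothetical helper: build a tree from a level-order list (None = missing child)."""
    if not values or values[0] is None:
        return None
    root = TreeNode(values[0])
    queue = deque([root])
    i = 1
    while queue and i < len(values):
        node = queue.popleft()
        # The next two entries in the list are this node's left and right children.
        if i < len(values) and values[i] is not None:
            node.left = TreeNode(values[i])
            queue.append(node.left)
        i += 1
        if i < len(values) and values[i] is not None:
            node.right = TreeNode(values[i])
            queue.append(node.right)
        i += 1
    return root


def to_list(root: Optional[TreeNode]) -> List[Optional[int]]:
    """Hypothetical helper: serialize a tree back to level order, trimming trailing Nones."""
    result = []
    queue = deque([root])
    while queue:
        node = queue.popleft()
        if node is None:
            result.append(None)
            continue
        result.append(node.val)
        queue.append(node.left)
        queue.append(node.right)
    while result and result[-1] is None:
        result.pop()
    return result


print(to_list(build_tree([1, 2, 5, 3, 4, None, 6])))  # [1, 2, 5, 3, 4, None, 6]

# The expected output decodes to a right-only chain: 1 -> 2 -> 3 -> 4 -> 5 -> 6.
chain = build_tree([1, None, 2, None, 3, None, 4, None, 5, None, 6])
```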
Solution
This problem asks us to convert a binary tree to a linked list in pre-order traversal order. Pre-order traverses a binary tree in root -> left -> right order. To convert a binary tree into a linear data structure, let’s break the problem down into three sub-problems:
- Flatten the left subtree in pre-order, and call the result flattened_left_subtree;
- Flatten the right subtree in pre-order, and call the result flattened_right_subtree;
- Link root, flattened_left_subtree, and flattened_right_subtree together in that order: root first, then the left part, then the right part.
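To see what the three sub-problems produce before touching any pointers, here is a minimal sketch that applies the same decomposition to plain Python lists. preorder_values is a hypothetical name; it builds a new list instead of rewiring the tree, so it only illustrates the target order, not the in-place solution.

```python
from typing import List, Optional


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def preorder_values(root: Optional[TreeNode]) -> List[int]:
    """Return node values in pre-order (root -> left -> right)."""
    if root is None:
        return []
    flattened_left_subtree = preorder_values(root.left)    # sub-problem 1
    flattened_right_subtree = preorder_values(root.right)  # sub-problem 2
    # Sub-problem 3: root first, then the left part, then the right part.
    return [root.val] + flattened_left_subtree + flattened_right_subtree


# The example tree [1,2,5,3,4,null,6], built by hand.
root = TreeNode(1,
                TreeNode(2, TreeNode(3), TreeNode(4)),
                TreeNode(5, None, TreeNode(6)))
print(preorder_values(root))  # [1, 2, 3, 4, 5, 6]
```

For the example, the left subtree yields [2, 3, 4] and the right subtree yields [5, 6], which matches the expected output order.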
Let’s draw a graph to see how it happens:
[Figures: (1) the given binary tree; (2) the tree after flattening its left and right subtrees; (3) the combined tree at the end.]
Code
The function we’re going to implement is def flatten(self, root: TreeNode) -> None. Let’s implement the three parts described above:
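- Step 1: flatten the left subtree. Since flatten itself flattens a given tree into a linked list, we can simply pass in the root of the left subtree: flatten(root.left) flattens it in place. The function returns None, so afterwards root.left points to the head of flattened_left_subtree.
- Step 2: flatten the right subtree. Similarly, flatten(root.right) flattens the right subtree in place, leaving root.right pointing to the head of flattened_right_subtree.
- Step 3: link the three parts together. If root.left is None, there is nothing to do, because root.right already points to flattened_right_subtree. Otherwise: 1) save root.right, the head of flattened_right_subtree; 2) let root.right point to the head of flattened_left_subtree and set root.left to None; 3) find the tail of flattened_left_subtree and let tail.right point to the saved head of flattened_right_subtree. The order matters: root.right must be saved before it is overwritten, or the right part is lost.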
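Step 3 can be tried on its own. The sketch below assumes both subtrees have already been flattened into right-only chains and performs just the linking; link_flattened is a hypothetical helper name, not part of the solution’s interface.

```python
class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def link_flattened(root: TreeNode) -> None:
    """Step 3 only: assumes root.left and root.right are already flattened chains."""
    if root.left is None:
        return  # nothing to splice in; root.right already holds the rest
    head_of_right = root.right  # 1) remember flattened_right_subtree first
    root.right = root.left      # 2) flattened_left_subtree now follows root
    root.left = None            #    the list uses only right pointers
    tail = root.right
    while tail.right:           # 3) walk to the tail of flattened_left_subtree
        tail = tail.right
    tail.right = head_of_right  #    and append flattened_right_subtree


# Subtrees of the example, already flattened: 2 -> 3 -> 4 and 5 -> 6.
left_chain = TreeNode(2, None, TreeNode(3, None, TreeNode(4)))
right_chain = TreeNode(5, None, TreeNode(6))
root = TreeNode(1, left_chain, right_chain)
link_flattened(root)

values = []
node = root
while node:
    values.append(node.val)
    node = node.right
print(values)  # [1, 2, 3, 4, 5, 6]
```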
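```python
class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


class Solution:
    def flatten(self, root: TreeNode) -> None:
        """
        Do not return anything, modify root in-place instead.
        """
        if not root:
            return None
        # Steps 1 and 2: flatten both subtrees in place.
        self.flatten(root.left)
        self.flatten(root.right)
        # Step 3: link root, flattened left subtree, flattened right subtree.
        if root.left:
            right = root.right      # head of the flattened right subtree
            root.right = root.left  # the flattened left subtree follows root
            root.left = None
            last = root
            while last.right:       # walk to the tail of the left part
                last = last.right
            last.right = right      # attach the flattened right subtree
```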
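One design note: the while loop re-walks each flattened left subtree to find its tail, so in the worst case (for example, a tree where every node has only a left child) the total work grows quadratically with the number of nodes. A possible variant, sketched below under the assumption that returning a value from the helper is acceptable, has the recursion return the tail of each flattened subtree so no walking is needed and every node is visited once. flatten_with_tail is a hypothetical name; like the solution above, it still relies on Python recursion, so very deep trees may hit the default recursion limit.

```python
from typing import Optional


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def flatten_with_tail(root: Optional[TreeNode]) -> Optional[TreeNode]:
    """Flatten root in place (pre-order) and return the tail of the resulting list."""
    if root is None:
        return None
    left_tail = flatten_with_tail(root.left)
    right_tail = flatten_with_tail(root.right)
    if root.left:
        # Splice: root -> flattened left subtree -> flattened right subtree.
        left_tail.right = root.right
        root.right = root.left
        root.left = None
    # The tail is the end of the right part, else the left part, else root itself.
    return right_tail or left_tail or root


# Usage on the example tree [1,2,5,3,4,null,6].
root = TreeNode(1,
                TreeNode(2, TreeNode(3), TreeNode(4)),
                TreeNode(5, None, TreeNode(6)))
tail = flatten_with_tail(root)

values = []
node = root
while node:
    values.append(node.val)
    node = node.right
print(values)    # [1, 2, 3, 4, 5, 6]
print(tail.val)  # 6
```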
Thanks for reading!