Given a root node reference of a BST and a key, delete the node with the given key in the BST. Return the root node reference (possibly updated) of the BST.
Basically, the deletion can be divided into two stages:
Search for a node to remove.
If the node is found, delete the node.
Follow up: Can you solve it with time complexity O(height of tree)?
Example 1:
Input: root = [5,3,6,2,4,null,7], key = 3
Output: [5,4,6,2,null,null,7]
Explanation: Given key to delete is 3. So we find the node with value 3 and delete it.
One valid answer is [5,4,6,2,null,null,7], shown in the above BST.
Please notice that another valid answer is [5,2,6,null,4,null,7] and it's also accepted.
Example 2:
Input: root = [5,3,6,2,4,null,7], key = 0
Output: [5,3,6,2,4,null,7]
Explanation: The tree does not contain a node with value = 0.
Example 3:
Input: root = [], key = 0
Output: []
Constraints:
The number of nodes in the tree is in the range [0, 10^4].
-10^5 <= Node.val <= 10^5
Each node has a unique value.
root is a valid binary search tree.
-10^5 <= key <= 10^5
Solution
/**
* Definition for a binary tree node.
* struct TreeNode {
* int val;
* TreeNode *left;
* TreeNode *right;
* TreeNode() : val(0), left(nullptr), right(nullptr) {}
* TreeNode(int x) : val(x), left(nullptr), right(nullptr) {}
* TreeNode(int x, TreeNode *left, TreeNode *right) : val(x), left(left), right(right) {}
* };
*/
class Solution {
public:
    /**
     * Unlinks the node whose value equals `key` from the BST rooted at `root`.
     *
     * The tree is walked downward once, comparing `key` with each node's value,
     * so the search visits at most one node per level.
     *
     * @param root Root of the tree; may be nullptr (empty tree).
     * @param key  Value of the node to remove.
     * @return The root of the tree after removal. This differs from `root`
     *         only when the root itself was the node removed. If no node
     *         holds `key`, the tree is returned unchanged.
     *
     * Note: the removed node is only detached from the tree. It is not freed,
     * and its own `left`/`right` members are left as they were.
     */
    TreeNode* deleteNode(TreeNode* root, int key) {
        // `link` addresses the pointer (the local `root`, or some node's
        // `left`/`right` member) that leads to the node being inspected.
        // Rewriting `*link` later re-attaches the replacement without needing
        // to know which side of the parent we came from.
        TreeNode** link = &root;
        while (*link != nullptr && (*link)->val != key) {
            link = (key < (*link)->val) ? &(*link)->left : &(*link)->right;
        }

        TreeNode* victim = *link;
        if (victim == nullptr) {
            return root;  // key is not present
        }

        *link = replacementFor(victim);
        return root;
    }

private:
    /**
     * Builds the subtree that takes the place of `victim` and returns its root.
     *
     * - No right child: the left subtree (possibly nullptr) moves up.
     * - Right child without a left child: the right child moves up and adopts
     *   the victim's left subtree.
     * - Otherwise: the leftmost node of the right subtree (the in-order
     *   successor) is cut out and adopts both of the victim's subtrees.
     */
    static TreeNode* replacementFor(TreeNode* victim) {
        TreeNode* rightChild = victim->right;
        if (rightChild == nullptr) {
            return victim->left;
        }

        if (rightChild->left == nullptr) {
            rightChild->left = victim->left;
            return rightChild;
        }

        // Descend along left links to find the successor and its parent.
        TreeNode* successorParent = rightChild;
        TreeNode* successor = rightChild->left;
        while (successor->left != nullptr) {
            successorParent = successor;
            successor = successor->left;
        }

        // The successor has no left child, so its right subtree fills its slot.
        successorParent->left = successor->right;
        successor->right = rightChild;
        successor->left = victim->left;
        return successor;
    }
};