15. 3Sum
Problem
Given an integer array nums, return all the triplets [nums[i], nums[j], nums[k]] such that i != j, i != k, and j != k, and nums[i] + nums[j] + nums[k] == 0.
Notice that the solution set must not contain duplicate triplets.
Example 1:
Input: nums = [-1,0,1,2,-1,-4]
Output: [[-1,-1,2],[-1,0,1]]
Explanation:
nums[0] + nums[1] + nums[2] = (-1) + 0 + 1 = 0.
nums[1] + nums[2] + nums[4] = 0 + 1 + (-1) = 0.
nums[0] + nums[3] + nums[4] = (-1) + 2 + (-1) = 0.
The distinct triplets are [-1,0,1] and [-1,-1,2].
Notice that the order of the output and the order of the triplets do not matter.
Example 2:
Input: nums = [0,1,1]
Output: []
Explanation: The only possible triplet does not sum up to 0.
Example 3:
Input: nums = [0,0,0]
Output: [[0,0,0]]
Explanation: The only possible triplet sums up to 0.
Constraints:
3 <= nums.length <= 3000
-10^5 <= nums[i] <= 10^5
Solution
We will follow the same two-pointer pattern as in Two Sum II. It requires the array to be sorted, so we'll do that first. As our best conceivable runtime (BCR) is O(n^2), the O(n log n) cost of sorting does not change the overall time complexity.
To make sure the result contains unique triplets, we need to skip duplicate values. It is easy to do because repeating values are next to each other in a sorted array.
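The duplicate-skipping idea can be shown in isolation. The sketch below is only an illustration: first_occurrence_indices is a hypothetical helper name, not part of the solution, and it assumes the input is already sorted. It keeps an index only when the value differs from the one before it.

```python
from typing import List


def first_occurrence_indices(sorted_nums: List[int]) -> List[int]:
    """Return the indices where a new value starts in a sorted list.

    Hypothetical helper for illustration; assumes sorted_nums is sorted.
    """
    return [
        i
        for i in range(len(sorted_nums))
        # Equal values are adjacent after sorting, so comparing with the
        # previous element is enough to detect a repeat.
        if i == 0 or sorted_nums[i - 1] != sorted_nums[i]
    ]


nums = sorted([-1, 0, 1, 2, -1, -4])      # [-4, -1, -1, 0, 1, 2]
pivots = first_occurrence_indices(nums)   # [0, 1, 3, 4, 5]
```

Index 2 is dropped because nums[2] repeats nums[1], so the value -1 would be considered as a pivot only once.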
After sorting the array, we move our pivot element nums[i] through the array and analyze the elements to its right. For each pivot, we find all pairs whose sum is equal to -nums[i] using the two-pointer pattern, so that the pivot element (nums[i]) and the pair (which sums to -nums[i]) add up to zero.
As a quick refresher, the pointers are initially set to the first and the last element of the range being searched. We compare the sum of these two elements to the target. If it is smaller, we increment the lower pointer lo. If it is larger, we decrement the higher pointer hi. Thus, the sum always moves toward the target, and we "prune" pairs that would move it further away. Again, this works only if the array is sorted. Head to the Two Sum II solution for the detailed explanation.
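The refresher can be sketched as a standalone function. This is a minimal illustration, not the Two Sum II solution itself: two_sum_sorted is a hypothetical name, it assumes a sorted input, and it returns 0-based indices of the first matching pair it finds (or None).

```python
from typing import List, Optional, Tuple


def two_sum_sorted(nums: List[int], target: int) -> Optional[Tuple[int, int]]:
    """Find indices (lo, hi) with nums[lo] + nums[hi] == target.

    Hypothetical helper for illustration; assumes nums is sorted ascending.
    """
    lo, hi = 0, len(nums) - 1
    while lo < hi:
        pair_sum = nums[lo] + nums[hi]
        if pair_sum < target:
            lo += 1      # sum too small: only a larger low element can help
        elif pair_sum > target:
            hi -= 1      # sum too large: only a smaller high element can help
        else:
            return lo, hi
    return None


found = two_sum_sorted([-4, -1, -1, 0, 1, 2], 1)   # (1, 5): -1 + 2 == 1
missing = two_sum_sorted([1, 2, 3], 10)            # None
```

In 3Sum the same loop runs on the part of the array to the right of the pivot, with -nums[i] as the target, and it keeps going after a match instead of returning.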
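from typing import List


class Solution:
    def threeSum(self, nums: List[int]) -> List[List[int]]:
        res = []
        nums.sort()
        for i in range(len(nums)):
            # A positive pivot cannot reach zero with the values to its right,
            # which are all at least as large.
            if nums[i] > 0:
                break
            # Skip repeated pivot values to avoid duplicate triplets.
            if i == 0 or nums[i - 1] != nums[i]:
                self.twoSumII(nums, i, res)
        return res

    def twoSumII(self, nums: List[int], i: int, res: List[List[int]]) -> None:
        # Two-pointer search over the elements to the right of the pivot nums[i].
        lo, hi = i + 1, len(nums) - 1
        while lo < hi:
            total = nums[i] + nums[lo] + nums[hi]  # avoids shadowing built-in sum
            if total < 0:
                lo += 1
            elif total > 0:
                hi -= 1
            else:
                res.append([nums[i], nums[lo], nums[hi]])
                lo += 1
                hi -= 1
                # Skip repeated values of the lower element.
                while lo < hi and nums[lo] == nums[lo - 1]:
                    lo += 1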
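One way to sanity-check a two-pointer implementation on small inputs is to compare it against a brute-force reference. The sketch below is not the approach described above: three_sum_brute_force is a hypothetical O(n^3) helper that tries every combination of three positions and uses a set of sorted tuples to drop duplicates. It is only practical for short arrays.

```python
from itertools import combinations
from typing import List


def three_sum_brute_force(nums: List[int]) -> List[List[int]]:
    """Return all unique zero-sum triplets by checking every combination.

    Hypothetical reference helper for testing; runs in O(n^3) time.
    """
    unique = {
        tuple(sorted(triplet))            # normalize order so duplicates collapse
        for triplet in combinations(nums, 3)
        if sum(triplet) == 0
    }
    # Sort the triplets so the output order is deterministic.
    return [list(t) for t in sorted(unique)]


example_1 = three_sum_brute_force([-1, 0, 1, 2, -1, -4])  # [[-1, -1, 2], [-1, 0, 1]]
example_2 = three_sum_brute_force([0, 1, 1])              # []
example_3 = three_sum_brute_force([0, 0, 0])              # [[0, 0, 0]]
```

Since the problem allows triplets in any order, a comparison against another implementation should sort both result lists first.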