/**
* Definition for a binary tree node.
* public class TreeNode {
* int val;
* TreeNode left;
* TreeNode right;
* TreeNode(int x) { val = x; }
* }
*/
/**
 * Counts the downward paths (parent-to-child, starting and ending at any node)
 * whose node values add up to a target, using prefix sums along the current
 * root-to-node path.
 */
public class Solution {
    /**
     * Visits {@code node} and its subtrees, adding to {@code res[0]} the number of
     * downward paths that end at each visited node and sum to {@code sum}.
     *
     * @param node        current node; {@code null} ends the recursion
     * @param sum         target path sum
     * @param res         single-element array used as the running path counter
     * @param prefixCount for each prefix sum on the current root-to-parent path
     *                    (plus the empty prefix 0), how many times it occurs
     * @param parentSum   sum of the values from the root down to the parent of
     *                    {@code node}; 0 when {@code node} is the root
     */
    private void helper(TreeNode node, int sum, int[] res, Map<Long, Integer> prefixCount, long parentSum) {
        if (node == null) {
            return;
        }
        // Accumulate in long: with int arithmetic a long path can wrap around and a
        // path whose true sum differs from the target by 2^32 would be counted.
        long curSum = parentSum + node.val;

        // Look up BEFORE registering curSum, otherwise a target of 0 would match
        // the empty path that starts and ends at this node.
        res[0] += prefixCount.getOrDefault(curSum - sum, 0);

        prefixCount.put(curSum, prefixCount.getOrDefault(curSum, 0) + 1);
        helper(node.left, sum, res, prefixCount, curSum);
        helper(node.right, sum, res, prefixCount, curSum);
        // Backtrack: this node's prefix must not be visible to sibling subtrees.
        prefixCount.put(curSum, prefixCount.get(curSum) - 1);
    }

    /**
     * Returns the number of downward paths in the tree whose values sum to {@code sum}.
     *
     * @param root root of the tree; may be {@code null}
     * @param sum  target path sum
     * @return the number of matching paths, 0 for an empty tree
     */
    public int pathSum(TreeNode root, int sum) {
        int[] res = new int[] {0};
        Map<Long, Integer> prefixCount = new HashMap<>();
        // The empty prefix lets paths that start at the root be counted.
        prefixCount.put(0L, 1);
        helper(root, sum, res, prefixCount, 0L);
        return res[0];
    }
}
/**
* Definition for a binary tree node.
* public class TreeNode {
* int val;
* TreeNode left;
* TreeNode right;
* TreeNode(int x) { val = x; }
* }
*/
/**
 * Counts the downward paths (parent-to-child, starting and ending at any node)
 * whose node values add up to a target, using a map of prefix-sum frequencies
 * for the current root-to-node path.
 */
public class Solution {
    /**
     * Returns the number of downward paths in the tree whose values sum to {@code sum}.
     *
     * @param root root of the tree; may be {@code null}
     * @param sum  target path sum
     * @return the number of matching paths, 0 for an empty tree
     */
    public int pathSum(TreeNode root, int sum) {
        if (root == null) {
            return 0;
        }
        Map<Long, Integer> map = new HashMap<>();
        // The empty prefix lets paths that start at the root be counted.
        map.put(0L, 1);
        return findPathSum(root, 0L, sum, map);
    }

    /**
     * Counts the matching paths that end at {@code curr} or anywhere below it.
     *
     * @param curr   current node; {@code null} yields 0
     * @param sum    sum of the values from the root down to the parent of {@code curr}
     * @param target target path sum
     * @param map    frequency of each prefix sum on the path from the root to the
     *               parent of {@code curr}, including the empty prefix 0
     * @return number of matching paths ending in the subtree rooted at {@code curr}
     */
    private int findPathSum(TreeNode curr, long sum, int target, Map<Long, Integer> map) {
        if (curr == null) {
            return 0;
        }
        // Extend the prefix sum with the current value. It is kept in long because
        // int arithmetic can wrap on long paths and report paths whose true sum
        // differs from the target by 2^32.
        sum += curr.val;
        // Paths ending at the current node: one per earlier prefix equal to sum - target.
        int numPathToCurr = map.getOrDefault(sum - target, 0);
        // Register the current prefix so descendants can use it as a path start.
        map.merge(sum, 1, Integer::sum);
        // Total = paths ending here + paths ending in the left subtree
        //       + paths ending in the right subtree.
        int res = numPathToCurr
                + findPathSum(curr.left, sum, target, map)
                + findPathSum(curr.right, sum, target, map);
        // Backtrack: remove the current prefix before returning to the parent so
        // it is not visible to sibling subtrees.
        map.merge(sum, -1, Integer::sum);
        return res;
    }
}