Count of Smaller Numbers After Self with Fenwick Tree (Binary Indexed Tree)
Given an integer array nums, return an integer array counts where counts[i] is the number of smaller elements to the right of nums[i].
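To make the definition concrete, here is a minimal brute-force sketch that translates the statement directly. The class name BruteForceCountSmaller is a hypothetical one chosen for illustration; it is not one of the solutions discussed in this post. It runs in O(n^2) time, so it is mainly useful as a baseline for checking faster approaches on small inputs.

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical helper class: a direct O(n^2) translation of the problem statement.
class BruteForceCountSmaller {
    static List<Integer> countSmaller(int[] nums) {
        List<Integer> counts = new ArrayList<>();
        for (int i = 0; i < nums.length; i++) {
            int smaller = 0;
            // Look only at elements to the right of position i
            for (int j = i + 1; j < nums.length; j++) {
                if (nums[j] < nums[i]) {
                    smaller++;
                }
            }
            counts.add(smaller);
        }
        return counts;
    }
}
```

For nums = [5, 2, 6, 1] this returns [2, 1, 1, 0]: 5 has 2 and 1 to its right, 2 has 1, 6 has 1, and the last element has nothing to its right.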
There are several ways to approach this problem. Some solutions use a modified merge sort, but the problem is also a good way to introduce the Fenwick tree.
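// Reference: https://en.wikipedia.org/wiki/Fenwick_tree
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.TreeMap;

class Solution {
    public List<Integer> countSmaller(int[] a) {
        int n = a.length;
        // arr is 1-indexed: arr[i] is the rank (1..number of distinct values) of a[i - 1]
        int[] arr = compressedArray(a);
        // arr.length is n + 1, an upper bound on the largest rank
        int val = arr.length;
        FenwickTree tree = new FenwickTree(val);
        List<Integer> ls = new ArrayList<>();
        // Scan right to left so the tree only holds elements to the right of position i
        for (int i = n; i >= 1; i--) {
            // Count the elements seen so far with a strictly smaller rank
            ls.add(tree.query(arr[i] - 1));
            tree.update(arr[i], 1);
        }
        // Counts were collected right to left, so restore the original order
        Collections.reverse(ls);
        return ls;
    }

    // Coordinate compression: map each value to its 1-based rank among the distinct values
    private int[] compressedArray(int[] a) {
        int n = a.length;
        Map<Integer, Integer> map = new TreeMap<>();
        for (int x : a) {
            map.put(x, 0);
        }
        int rank = 1;
        // TreeMap iterates keys in ascending order, so ranks follow sorted order
        for (Map.Entry<Integer, Integer> entry : map.entrySet()) {
            entry.setValue(rank);
            rank++;
        }
        int[] arr = new int[n + 1];
        for (int i = 1; i <= n; i++) {
            arr[i] = map.get(a[i - 1]);
        }
        return arr;
    }
}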
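class FenwickTree {
    int[] bit; // 1-based: bit[i] holds the sum of the (i & -i) positions ending at i
    int n;

    FenwickTree(int n) {
        this.n = n;
        this.bit = new int[n + 1];
    }

    // Adds val at position i; i must be >= 1, because i = 0 would never advance
    public void update(int i, int val) {
        while (i < bit.length) {
            bit[i] += val;
            i += (i & (-i)); // move to the next node that covers position i
        }
    }

    // Returns the prefix sum of positions 1..i (0 when i == 0)
    public int query(int i) {
        int sum = 0;
        while (i > 0) {
            sum += bit[i];
            i -= (i & (-i)); // drop the lowest set bit to jump to the preceding range
        }
        return sum;
    }

    // Returns the sum of positions l..r, inclusive
    public int rangeSum(int l, int r) {
        return query(r) - query(l - 1);
    }
}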
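Explanation of the Fenwick tree (also called a binary indexed tree, or BIT): it stores partial sums in an array so that both a point update and a prefix-sum query take O(log n) time. Entry i covers the i & (-i) positions ending at i, which is why update adds the lowest set bit of i and query subtracts it. In this problem the tree is indexed by rank and counts how many values of each rank have been seen while scanning from right to left, so query(rank - 1) is the number of smaller elements to the right.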
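As a rough illustration of how the tree behaves, the sketch below traces the right-to-left scan for nums = [5, 2, 6, 1], whose ranks are [3, 2, 4, 1]. The names FenwickSketch, lowbit, and trace are hypothetical and exist only for this example; the logic mirrors the update and query loops of the FenwickTree class above.

```java
// Hypothetical demo class: a minimal Fenwick tree plus a trace of the scan.
class FenwickSketch {
    private final int[] tree; // 1-based; tree[i] sums the (i & -i) positions ending at i

    FenwickSketch(int size) {
        tree = new int[size + 1];
    }

    // Lowest set bit of i, i.e. how many positions tree[i] covers
    static int lowbit(int i) {
        return i & (-i);
    }

    // Add delta at position i (1-based) by visiting every node that covers i
    void update(int i, int delta) {
        for (; i < tree.length; i += lowbit(i)) {
            tree[i] += delta;
        }
    }

    // Sum of positions 1..i, assembled from disjoint covered ranges
    int query(int i) {
        int sum = 0;
        for (; i > 0; i -= lowbit(i)) {
            sum += tree[i];
        }
        return sum;
    }

    // Right-to-left scan over the ranks of nums = [5, 2, 6, 1]
    static int[] trace() {
        int[] ranks = {3, 2, 4, 1};
        FenwickSketch seen = new FenwickSketch(4);
        int[] counts = new int[ranks.length];
        for (int i = ranks.length - 1; i >= 0; i--) {
            // Number of already-seen ranks strictly below ranks[i]
            counts[i] = seen.query(ranks[i] - 1);
            // Record that this rank has now been seen
            seen.update(ranks[i], 1);
        }
        return counts;
    }
}
```

Calling FenwickSketch.trace() yields [2, 1, 1, 0]. Each step does one query and one update, so the whole scan costs O(n log n) once the ranks are known.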
Here is a merge sort solution for comparison:
import java.util.ArrayList;
import java.util.List;

class Solution {
    // count[i] accumulates the number of smaller elements to the right of nums[i]
    int[] count;

    public List<Integer> countSmaller(int[] nums) {
        List<Integer> res = new ArrayList<Integer>();
        count = new int[nums.length];
        // Sort indexes instead of values so each element keeps its original position
        int[] indexes = new int[nums.length];
        for (int i = 0; i < nums.length; i++) {
            indexes[i] = i;
        }
        mergesort(nums, indexes, 0, nums.length - 1);
        for (int i = 0; i < count.length; i++) {
            res.add(count[i]);
        }
        return res;
    }

    private void mergesort(int[] nums, int[] indexes, int start, int end) {
        if (end <= start) {
            return;
        }
        int mid = start + (end - start) / 2;
        mergesort(nums, indexes, start, mid);
        mergesort(nums, indexes, mid + 1, end);
        merge(nums, indexes, start, end);
    }

    private void merge(int[] nums, int[] indexes, int start, int end) {
        int mid = start + (end - start) / 2;
        int leftIndex = start;
        int rightIndex = mid + 1;
        // Number of right-half elements already merged ahead of the current left element
        int rightCount = 0;
        int[] newIndexes = new int[end - start + 1];
        int sortIndex = 0;
        while (leftIndex <= mid && rightIndex <= end) {
            if (nums[indexes[rightIndex]] < nums[indexes[leftIndex]]) {
                // A strictly smaller right element jumps ahead of the remaining left elements
                newIndexes[sortIndex] = indexes[rightIndex];
                rightCount++;
                rightIndex++;
            } else {
                // Every right element merged so far is smaller than this left element
                newIndexes[sortIndex] = indexes[leftIndex];
                count[indexes[leftIndex]] += rightCount;
                leftIndex++;
            }
            sortIndex++;
        }
        // Left elements that remain are larger than all merged right elements
        while (leftIndex <= mid) {
            newIndexes[sortIndex] = indexes[leftIndex];
            count[indexes[leftIndex]] += rightCount;
            leftIndex++;
            sortIndex++;
        }
        while (rightIndex <= end) {
            newIndexes[sortIndex++] = indexes[rightIndex++];
        }
        // Copy the merged order back into the shared index array
        for (int i = start; i <= end; i++) {
            indexes[i] = newIndexes[i - start];
        }
    }
}
Posted by Jamie Meyer 15 days ago