# CodeSprint 2 two's complement problem: my solution is too slow

The original InterviewStreet CodeSprint has a problem that asks you to count the number of ones in the two's complement representations of all numbers between a and b inclusive. Using iteration, I passed all the test cases for correctness, but only two of them within the time limit. A hint mentioned finding a recurrence relation, so I switched to recursion, but it took just as long. Can anyone find a faster way to do this than the code below? The first number in the input file is the number of test cases in the file. A sample input file follows the code.

```java
import java.util.Scanner;

public class Solution {
    public static void main(String[] args) {
        Scanner scanner = new Scanner(System.in);
        int numCases = scanner.nextInt();
        for (int i = 0; i < numCases; i++) {
            int a = scanner.nextInt();
            int b = scanner.nextInt();
            System.out.println(count(a, b));
        }
    }

    /**
     * Returns the number of ones between a and b inclusive
     */
    public static int count(int a, int b) {
        int count = 0;
        for (int i = a; i <= b; i++) {
            if (i < 0)
                count += (32 - countOnes((-i) - 1, 0));
            else
                count += countOnes(i, 0);
        }
        return count;
    }

    /**
     * Returns the number of ones in a
     */
    public static int countOnes(int a, int count) {
        if (a == 0)
            return count;
        if (a % 2 == 0)
            return countOnes(a / 2, count);
        else
            return countOnes((a - 1) / 2, count + 1);
    }
}
```

Input:

```
3
-2 0
-3 4
-1 4
```

Output:

```
63
99
37
```
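To see where the sample outputs come from: in 32-bit two's complement, a negative number $-(m+1)$ is the bitwise complement of $m$, so it has $32 - \mathrm{popCount}(m)$ set bits. For example, $-3$ (the complement of $2$) has 31, $-2$ (the complement of $1$) has 31 and $-1$ (the complement of $0$) has 32. Summing the per-number counts reproduces each expected line:

$$
\begin{aligned}
[-2, 0]&:\; 31 + 32 + 0 = 63 \\
[-3, 4]&:\; 31 + 31 + 32 + 0 + 1 + 1 + 2 + 1 = 99 \\
[-1, 4]&:\; 32 + 0 + 1 + 1 + 2 + 1 = 37
\end{aligned}
$$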


The first step is to replace

```java
public static int countOnes(int a, int count) {
    if (a == 0)
        return count;
    if (a % 2 == 0)
        return countOnes(a / 2, count);
    else
        return countOnes((a - 1) / 2, count + 1);
}
```

which recurses to a depth of $\log_2 a$, with a faster implementation, such as the well-known bit-twiddling population count

```java
public static int popCount(int n) {
    // count the set bits in each bit-pair
    // 11 -> 10, 10 -> 01, 0* -> 0*
    n -= (n >>> 1) & 0x55555555;
    // count bits in each nibble
    n = ((n >>> 2) & 0x33333333) + (n & 0x33333333);
    // count bits in each byte
    n = ((n >> 4) & 0x0F0F0F0F) + (n & 0x0F0F0F0F);
    // accumulate the counts in the highest byte and shift.
    // Java guarantees wrap-around, so we can use int here;
    // in C, one would need to use unsigned or a 64-bit type
    // to avoid undefined behaviour
    return (0x01010101 * n) >> 24;
}
```

which uses four shifts, five bitwise ANDs, one subtraction, two additions and one multiplication, for a total of just thirteen very cheap instructions.
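In Java, `Integer.bitCount(int)` in `java.lang` already computes the same population count, so a quick way to gain confidence in the hand-written version is to compare the two. A minimal sketch of such a check follows; the class name `PopCountCheck`, the edge cases and the fixed random seed are arbitrary choices for illustration.

```java
import java.util.Random;

// Sketch: compares the bit-twiddling popCount with the JDK's Integer.bitCount.
class PopCountCheck {
    // Same SWAR popcount as above, repeated so this block compiles on its own.
    static int popCount(int n) {
        n -= (n >>> 1) & 0x55555555;                       // counts per bit-pair
        n = ((n >>> 2) & 0x33333333) + (n & 0x33333333);   // counts per nibble
        n = ((n >> 4) & 0x0F0F0F0F) + (n & 0x0F0F0F0F);    // counts per byte
        return (0x01010101 * n) >> 24;                     // sum of bytes in top byte
    }

    public static void main(String[] args) {
        int[] edgeCases = {0, 1, -1, Integer.MIN_VALUE, Integer.MAX_VALUE, 0x55555555};
        for (int n : edgeCases) {
            if (popCount(n) != Integer.bitCount(n)) throw new AssertionError("mismatch at " + n);
        }
        Random random = new Random(42); // fixed seed so runs are reproducible
        for (int i = 0; i < 1_000_000; i++) {
            int n = random.nextInt(); // uniform over all 32-bit ints
            if (popCount(n) != Integer.bitCount(n)) throw new AssertionError("mismatch at " + n);
        }
        System.out.println("popCount agrees with Integer.bitCount");
    }
}
```

In practice, `Integer.bitCount` is the simpler choice in Java, and a JIT may compile it to a single hardware instruction where the CPU supports one.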

But unless the ranges are very small, you can do **much** better than counting the bits of each number individually.

Consider non-negative numbers first. The numbers from $0$ to $2^k - 1$ have at most $k$ bits, and each bit position is set in exactly half of them, so together they contain $k \cdot 2^{k-1}$ set bits.
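The closed form $k \cdot 2^{k-1}$ is easy to confirm by brute force for small $k$. A minimal sketch, assuming `Integer.bitCount` as the per-number count; the class and method names are made up for this example.

```java
// Sketch: brute-force check that the numbers 0 .. 2^k - 1 contain k * 2^(k-1) set bits.
class HalfSetCheck {
    // Sum of Integer.bitCount over 0 <= n < 2^k, computed the slow way.
    static long bruteForce(int k) {
        long total = 0;
        for (int n = 0; n < (1 << k); n++) {
            total += Integer.bitCount(n);
        }
        return total;
    }

    // Closed form from the argument above: k bit positions, each set in 2^(k-1) numbers.
    static long closedForm(int k) {
        return k == 0 ? 0 : (long) k << (k - 1);
    }

    public static void main(String[] args) {
        for (int k = 0; k <= 20; k++) {
            if (bruteForce(k) != closedForm(k)) throw new AssertionError("k = " + k);
        }
        System.out.println(closedForm(3)); // prints 12: the numbers 0..7 hold 12 set bits
    }
}
```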

Now let $2^k \le a < 2^{k+1}$. The total number of set bits in the numbers $0 \le n \le a$ is the sum of the bits in the numbers $0 \le n < 2^k$ and the bits in the numbers $2^k \le n \le a$. The first count, as we saw above, is $k \cdot 2^{k-1}$. The second part consists of $a - 2^k + 1$ numbers, each of which has the $2^k$ bit set; ignoring that leading bit, their bits are the same as those of the numbers $0 \le n \le a - 2^k$. So, with $\mathrm{totalBits}(0) = 0$,

$$\mathrm{totalBits}(a) = k \cdot 2^{k-1} + (a - 2^k + 1) + \mathrm{totalBits}(a - 2^k)$$
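As a worked example of the recurrence, take $a = 13$, so $k = 3$ because $8 \le 13 < 16$; the recursive call then lands on $5$, where $k = 2$:

$$
\begin{aligned}
\mathrm{totalBits}(5) &= 2 \cdot 2^{1} + (5 - 4 + 1) + \mathrm{totalBits}(1) = 4 + 2 + 1 = 7 \\
\mathrm{totalBits}(13) &= 3 \cdot 2^{2} + (13 - 8 + 1) + \mathrm{totalBits}(5) = 12 + 6 + 7 = 25
\end{aligned}
$$

Counting directly agrees: the numbers $0$ to $7$ hold $12$ set bits, and $8, 9, 10, 11, 12, 13$ hold $1 + 2 + 2 + 3 + 2 + 3 = 13$, for a total of $25$.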

Now for negative numbers. In two's complement, `-(n+1) == ~n`, so the numbers $-a \le n \le -1$ are the bitwise complements of the numbers $0 \le m \le a - 1$. The complement of a 32-bit $m$ has $32 - \mathrm{popCount}(m)$ set bits, so the total number of set bits in the numbers $-a \le n \le -1$ is $32a - \mathrm{totalBits}(a - 1)$.
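The complement identity can also be checked by brute force over small $a$. A minimal sketch, assuming `Integer.bitCount` for the per-number count and brute-force summation in place of the recursive `totalBits`; the class and method names are hypothetical.

```java
// Sketch: checks that the set bits in -a .. -1 total 32*a - (set bits in 0 .. a-1).
class NegativeRangeCheck {
    // Brute-force set-bit total over lo <= n <= hi using the JDK popcount.
    // hi must be less than Integer.MAX_VALUE, or the loop never terminates.
    static long bruteForce(int lo, int hi) {
        long total = 0;
        for (int n = lo; n <= hi; n++) {
            total += Integer.bitCount(n);
        }
        return total;
    }

    public static void main(String[] args) {
        for (int a = 1; a <= 1000; a++) {
            // ~m == -(m + 1), so -a..-1 are the complements of 0..a-1,
            // and each complement has 32 - bitCount(m) set bits.
            long viaComplement = 32L * a - bruteForce(0, a - 1);
            if (viaComplement != bruteForce(-a, -1)) throw new AssertionError("a = " + a);
        }
        System.out.println(bruteForce(-3, -1)); // prints 94: 31 + 31 + 32
    }
}
```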

For the total number of set bits in a range $a \le n \le b$, we add or subtract these totals, depending on whether the two ends of the range have opposite signs or the same sign.

```java
// if n >= 0, return the total of set bits for
// the numbers 0 <= k <= n
// if n < 0, return the total of set bits for
// the numbers n <= k <= -1
public static long totalBits(int n) {
    if (n < 0) {
        long a = -(long) n;
        return (a * 32 - totalBits((int) (a - 1)));
    }
    if (n < 3) return n;
    int lg = 0, mask = n;
    // find the highest set bit in n and its position
    while (mask > 1) {
        ++lg;
        mask >>= 1;
    }
    mask = 1 << lg;
    // total bit count for 0 <= k < 2^lg
    long total = 1L << (lg - 1);
    total *= lg;
    // add number of 2^lg bits
    total += n + 1 - mask;
    // add number of other bits for 2^lg <= k <= n
    total += totalBits(n - mask);
    return total;
}

// return total set bits for the numbers a <= n <= b
public static long totalBits(int a, int b) {
    if (b < a) throw new IllegalArgumentException("Invalid range");
    if (a == b) return popCount(a);
    if (b == 0) return totalBits(a);
    if (b < 0) return totalBits(a) - totalBits(b + 1);
    if (a == 0) return totalBits(b);
    if (a > 0) return totalBits(b) - totalBits(a - 1);
    // Now a < 0 < b
    return totalBits(a) + totalBits(b);
}
```
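To tie this back to the question, here is a minimal end-to-end sketch that reads the input format described there (the number of test cases, then one `a b` pair per line) and answers each case with the range formula. It follows the methods above, but for brevity swaps in two JDK methods: `Integer.bitCount` in place of the hand-written `popCount`, and `Integer.numberOfLeadingZeros` in place of the loop that finds the highest set bit. The class name `FastSolution` is made up for this example.

```java
import java.util.Scanner;

// Sketch: the totalBits approach wired into the question's input format.
class FastSolution {
    static int popCount(int n) {
        return Integer.bitCount(n); // JDK popcount, same result as the bit-twiddling version
    }

    // n >= 0: set bits in 0..n; n < 0: set bits in n..-1
    static long totalBits(int n) {
        if (n < 0) {
            long a = -(long) n;                      // safe even for Integer.MIN_VALUE
            return a * 32 - totalBits((int) (a - 1));
        }
        if (n < 3) return n;                         // 0 -> 0, 1 -> 1, 2 -> 1 + 1
        int lg = 31 - Integer.numberOfLeadingZeros(n); // position of the highest set bit
        int mask = 1 << lg;
        long total = (long) lg << (lg - 1);          // bits in 0 .. 2^lg - 1
        total += (long) n + 1 - mask;                // the 2^lg bit in 2^lg .. n
        return total + totalBits(n - mask);          // remaining low bits
    }

    // Set bits in all numbers a <= n <= b.
    static long totalBits(int a, int b) {
        if (b < a) throw new IllegalArgumentException("Invalid range");
        if (a == b) return popCount(a);
        if (b == 0) return totalBits(a);
        if (b < 0) return totalBits(a) - totalBits(b + 1);
        if (a == 0) return totalBits(b);
        if (a > 0) return totalBits(b) - totalBits(a - 1);
        return totalBits(a) + totalBits(b);          // a < 0 < b
    }

    public static void main(String[] args) {
        Scanner scanner = new Scanner(System.in);
        int numCases = scanner.nextInt();
        StringBuilder out = new StringBuilder();
        for (int i = 0; i < numCases; i++) {
            int a = scanner.nextInt();
            int b = scanner.nextInt();
            out.append(totalBits(a, b)).append('\n');
        }
        System.out.print(out);
    }
}
```

Fed the sample input from the question, this should print `63`, `99` and `37`. Each query costs at most about 32 recursive steps regardless of the range width, and the `long` return type avoids the `int` overflow that the per-number loop would hit on wide ranges.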
