coderplay
1/29/2013 - 9:59 AM

sorted array intersection

sorted array intersection

k=   1, iteration=   8,871,669 ns, binarysearch=  10,120,962 ns, hybrid=   2,265,420 ns
k=   2, iteration=   2,102,584 ns, binarysearch=   7,615,600 ns, hybrid=   2,070,440 ns
k=   4, iteration=   2,778,757 ns, binarysearch=   7,277,803 ns, hybrid=   2,734,949 ns
k=   8, iteration=   3,994,027 ns, binarysearch=   9,921,741 ns, hybrid=   3,974,194 ns
k=  16, iteration=   6,457,776 ns, binarysearch=  13,513,792 ns, hybrid=   6,750,129 ns
k=  32, iteration=  11,124,515 ns, binarysearch=  14,100,162 ns, hybrid=  11,092,121 ns
k=  64, iteration=  20,693,773 ns, binarysearch=  17,691,459 ns, hybrid=  20,694,222 ns
k= 128, iteration=  39,954,362 ns, binarysearch=  22,977,556 ns, hybrid=  22,778,395 ns
k= 256, iteration=  78,491,416 ns, binarysearch=  29,402,883 ns, hybrid=  29,321,736 ns
k= 512, iteration= 155,645,213 ns, binarysearch=  36,933,250 ns, hybrid=  37,019,711 ns
k=1024, iteration= 310,509,245 ns, binarysearch=  46,384,737 ns, hybrid=  46,728,891 ns
k=2048, iteration= 625,296,392 ns, binarysearch=  55,864,950 ns, hybrid=  55,948,698 ns


Note: when the results above were recorded, the JIT had not yet warmed up at k=1, so the timings in that row are not representative.
import java.util.Arrays;
import java.util.Random;

public class IntArrayIntersection {

  /*
   * O(N+M)
   */
  public static int[] intersectByLiteration(final int[] lhs, final int[] rhs) {
    // worst case size for the intersection array
    int[] tmp = new int[Math.min(lhs.length, rhs.length)];
    int leftOffset = 0, rightOffset = 0, tmpOffset = 0;

    while (leftOffset < lhs.length && rightOffset < rhs.length) {
      if (lhs[leftOffset] == rhs[rightOffset]) {
        // check if element already exists
//        if (tmpOffset == 0 || lhs[leftOffset] != tmp[tmpOffset - 1]) {
          tmp[tmpOffset++] = lhs[leftOffset];
//        }
        leftOffset++;
        rightOffset++;
      } else if (lhs[leftOffset] > rhs[rightOffset]) {
        rightOffset++;
      } else if (lhs[leftOffset] < rhs[rightOffset]) {
        leftOffset++;
      }
    }

    return Arrays.copyOfRange(tmp, 0, tmpOffset);
  }
  
  /**
   * O(NlogM), while M is the bigger array
   */
  public static int[] intersectByBinarySearch(final int[] lhs, final int[] rhs) {
    // worst case size for the intersection array
    int[] tmp = new int[Math.min(lhs.length, rhs.length)];
    int tmpOffset = 0;
    for (int element : lhs) {
      if (binarySearch(rhs, element) >= 0) {
//        if (tmpOffset == 0 || tmp[tmpOffset - 1] != element) {
          tmp[tmpOffset++] = element;
//        }
      }
    }
    return Arrays.copyOfRange(tmp, 0, tmpOffset);
  }
  
  private static int binarySearch(int[] a, int key) {
    return binarySearch(a, 0, a.length, key);
  }

  private static int binarySearch(int[] a, int fromIndex, int toIndex, int key) {
    int low = fromIndex;
    int high = toIndex - 1;

    while (low <= high) {
      int mid = (low + high) >>> 1;
      int midVal = a[mid];

      if (midVal < key)
        low = mid + 1;
      else if (midVal > key)
        high = mid - 1;
      else
        return mid; // key found
    }
    return -(low + 1); // key not found.
  }
  

  static final int[] skips = { 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, 233, 377,
  610, 987, 1597, 2584, 4181, 6765, 10946, 17711, 28657, 46368, 75025,
  121393, 196418, 317811 };
//static final int[] skips = { 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048,
//    4096, 8192 };
  public static int[] intersectBySkipList(final int[] lhs, 
      final int[] rhs) {
    int[] tmp = new int[Math.min(lhs.length, rhs.length)];
    int leftOffset = 0, rightOffset = 0, tmpOffset = 0;

    int leftSize = lhs.length;
    int rightSize = rhs.length;

    while (leftOffset < leftSize && rightOffset < rightSize) {
      if (lhs[leftOffset] < rhs[rightOffset]) {
        int start = leftOffset + skips[0];
        if (start < leftSize && lhs[start] <= rhs[rightOffset]) {
          int i = 0;
          do {
            leftOffset = start;
            start += skips[i++];
          } while (start < leftSize && lhs[start] <= rhs[rightOffset]);
        } else {
          leftOffset++;
        }
      }  else if (rhs[rightOffset] < lhs[leftOffset]) {
        int start = rightOffset + skips[0];
        if (start < rightSize && rhs[start] <= lhs[leftOffset]) {
          int i = 0;
          do {
            rightOffset = start;
            start += skips[i++];
          } while (start < rightSize && rhs[start] <= lhs[leftOffset]);
        } else {
          rightOffset++;
        }
      } else {
        tmp[tmpOffset++] = lhs[leftOffset];
        leftOffset++;
        rightOffset++;
      } 
    }

    return Arrays.copyOfRange(tmp, 0, tmpOffset);
  }
  
  
  public static int[] intersectByHybrid(final int[] left, 
      final int[] right) {
    int[] lhs, rhs;
    if (left.length <= right.length) {
      lhs = left;
      rhs = right;
    } else {
      lhs = right;
      rhs = left;
    }
    if ((lhs.length << 6) < rhs.length) {
      return intersectByBinarySearch(lhs, rhs);
    } else {
      return intersectByLiteration(lhs, rhs);
    }
  }

  public static void main(String[] args) {
    int N = 100000, M;
    int[] A = new int[N];

    Random rand = new Random();

    for(int k = 1; k <= 2048; k <<= 1 ) {
      M = k * N;
      int[] B = new int[M];

      for (int i = 0; i < M; i++) {
        B[i] = rand.nextInt();
        if (i < N)
          A[i] = rand.nextInt();
      }

      Arrays.sort(A);
      Arrays.sort(B);

      long tStart = System.nanoTime();
      int[] c1 = intersectByLiteration(A, B);
      long t1 = System.nanoTime() - tStart;

      tStart = System.nanoTime();
      int[] c2 = intersectByBinarySearch(A, B);
      long t2 = System.nanoTime() - tStart;
      
      tStart = System.nanoTime();
      int[] c3 = intersectByHybrid(A, B);
      long t3 = System.nanoTime() - tStart;
      
//      tStart = System.nanoTime();
//      int[] c3 = intersectBySkipList(A, B);
//      long t3 = System.nanoTime() - tStart;

//      System.out.println("k=" + k + ", cost time t1=" + t1 + "ns, t2=" + t2
//          + "ns, t3=" + t3);
//
//      System.out.println("c1=" + c1.length + ", c2=" + c2.length + ", c3=" + c3.length);
//      
//      System.out.println("k=" + k + ", cost time t1=" + t1 + "ns, t2=" + t2
//          + "ns");
//
//      System.out.println("c1=" + c1.length + ", c2=" + c2.length);
      System.out.format(
          "k=%4d, iteration=%,12d ns, binarysearch=%,12d ns, hybrid=%,12d ns%n",
          k, t1, t2, t3);
//      System.out.println("c1=" + c1.length + ", c2=" + c2.length + ", c3="
//          + c3.length);
    }
  }
}