/*
 * coderplay
 * 9/13/2012 - 3:59 AM
 *
 * SegmentedEratosthenes
 */

import java.util.BitSet;

/**
 * Shamelessly copied some code from <a
 * href="mailto:zhong.lunfu@gmail.com">zhongl</a>.
 * 
 * <p>
 * The basic idea of a segmented sieve is to choose the sieving primes less than
 * the square root of n, choose a reasonably large segment size that
 * nevertheless fits in memory, and then sieve each of the segments in turn,
 * starting with the smallest. At the first segment, the smallest multiple of
 * each sieving prime that is within the segment is calculated, then multiples
 * of the sieving prime are marked as composite in the normal way; when all the
 * sieving primes have been used, the remaining unmarked numbers in the segment
 * are prime. Then, for the next segment, for each sieving prime you already
 * know the first multiple in the current segment (it was the multiple that
 * ended the sieving for that prime in the prior segment), so you sieve on each
 * sieving prime, and so on until you are finished.
 * 
 * <p>
 * Algorithm description: <a href=
 * "http://programmingpraxis.com/2010/02/05/segmented-sieve-of-eratosthenes/">
 * Segmented Sieve Of Eratosthenes</a>.
 * 
 * @version version for Coding4Fun
 * @author Min Zhou(coderplay@gmail.com)
 */
public class SegmentedEratosthenes {

  /**
   * Usage: java SegmentedEratosthenes N C
   *
   * <p>Sieves every value up to N, prints the time taken, then prints a window
   * of results taken from the middle of the ascending list.
   *
   * @param args exactly two positive integers, N (upper limit) and C
   * @throws IllegalArgumentException if the argument count is wrong or either
   *         value is not positive ({@link NumberFormatException}, a subclass,
   *         if either is not an integer)
   */
  public static void main(String[] args) {
    if (args.length != 2)
      throw new IllegalArgumentException("Usage: java SegmentedEratosthenes N C");

    final int number = Integer.parseInt(args[0]);
    final int count = Integer.parseInt(args[1]);

    if (number <= 0 || count <= 0)
      throw new IllegalArgumentException("N and C should be natural numbers");

    countAndPrintPrime(number, count);
  }

  /**
   * Sieves up to {@code number}, prints the sieve time in nanoseconds, and
   * prints the middle {@code 2 * count} results (or {@code 2 * count - 1} when
   * the total is odd, or all of them when there are fewer).
   */
  private static void countAndPrintPrime(final int number, final int count) {
    final long start = System.nanoTime();
    final Primes primes = new PrimeDetector(number).detect();
    System.out.println("Cost: " + (System.nanoTime() - start) + "ns");

    // An odd total minus an odd window is even, so the window sits exactly centered.
    final long wanted = 2L * count - (primes.count() % 2 == 0 ? 0 : 1);
    // 2 * count can exceed int range; any size >= primes.count() prints everything anyway.
    final int size = (int) Math.min(wanted, Integer.MAX_VALUE);
    final String result = primes.outputMiddleOf(size);
    System.out.println(number + " " + count + ": " + result);
  }

}

/**
 * Sieve of Eratosthenes over the odd values up to a limit, processed in
 * consecutive segments of {@code SEGMENT_SIZE} values.
 *
 * <p>Flag layout (read back by {@link Primes}): index 0 stands for 1, index 1
 * for 2, and index {@code k >= 2} for the odd value {@code 2k - 1}, so an odd
 * value {@code v} lives at index {@code (v + 1) / 2}. A set bit marks a
 * composite; indices 0 and 1 are never set.
 */
class PrimeDetector {
  /** Number of consecutive values sieved per pass. */
  private static final int SEGMENT_SIZE = 128 * 1024;

  /** Inclusive upper limit of the sieve. */
  private final int number;

  /**
   * Creates a detector for all values up to and including {@code number}.
   *
   * @param number inclusive upper limit; any value below 2 yields a one-slot filter
   */
  public PrimeDetector(int number) {
    this.number = number;
  }

  /**
   * Runs the sieve.
   *
   * @return the composite flags together with the number of unmarked slots
   */
  public Primes detect() {
    // Only odd values get a slot, plus one extra slot for the even prime 2.
    // ">>>" keeps (number + 1) / 2 correct when number + 1 wraps past Integer.MAX_VALUE.
    final int memorySize = (number > 1) ? ((number + 1) >>> 1) + 1 : 1;
    final FastBitSet filter = new FastBitSet(memorySize);

    // Segment bounds are longs: "from += SEGMENT_SIZE" on an int would wrap
    // (and loop forever) when number is close to Integer.MAX_VALUE.
    for (long from = 2; from <= number; from += SEGMENT_SIZE) {
      final long to = Math.min(from + SEGMENT_SIZE - 1, number);
      processSingleSegment(filter, from, to);
    }

    final int numPrimes = memorySize - filter.cardinality();
    return new Primes(filter, numPrimes);
  }

  /**
   * Marks every odd composite in the closed range [{@code from}, {@code to}].
   *
   * <p>All arithmetic is in {@code long}: with {@code to} near
   * {@code Integer.MAX_VALUE}, {@code i * i} and {@code j + 2i} would overflow an int.
   */
  private static void processSingleSegment(final FastBitSet flags,
                                           final long from,
                                           final long to) {
    for (long i = 3; i * i <= to; i += 2) {
      // Skip odd multiples of 3, 5, 7, 11, 13 and 17: every multiple of such an
      // i is already marked by the smaller prime.
      if (  (i >= 3  * 3  && i % 3  == 0)
         || (i >= 5  * 5  && i % 5  == 0)
         || (i >= 7  * 7  && i % 7  == 0)
         || (i >= 11 * 11 && i % 11 == 0)
         || (i >= 13 * 13 && i % 13 == 0)
         || (i >= 17 * 17 && i % 17 == 0))
        continue;

      // First multiple of i inside the segment, but never below i * i: smaller
      // multiples have a smaller prime factor that marks them.
      long start = Math.max(((from + i - 1) / i) * i, i * i);
      // Even values have no slot; move to the next (odd) multiple.
      if ((start & 1) == 0)
        start += i;

      // i is odd, so stepping by 2i visits only odd multiples.
      final long step = i << 1;
      for (long j = start; j <= to; j += step) {
        flags.fastSet((int) ((j + 1) >> 1));
      }
    }
  }
}

/**
 * Result of a sieve: the composite-flag bit set plus the number of clear
 * (surviving) slots it contains. Slot 0 stands for 1, slot 1 for 2, and slot
 * {@code k >= 2} for {@code 2k - 1}; note that 1 is reported as a survivor.
 */
class Primes {

  private final FastBitSet filter;
  private final int numPrimes;

  Primes(FastBitSet filter, int numOfPrimes) {
    this.filter = filter;
    this.numPrimes = numOfPrimes;
  }

  /** Returns how many slots survived the sieve. */
  public int count() {
    return numPrimes;
  }

  /**
   * Renders {@code outputNum} survivors taken from the middle of the ascending
   * sequence, each followed by a space. If fewer than {@code outputNum} exist,
   * all of them are rendered.
   */
  public String outputMiddleOf(int outputNum) {
    if (numPrimes >= outputNum) {
      final int skipped = (numPrimes - outputNum) >> 1;
      return output(skipped, outputNum);
    }
    return output(0, numPrimes);
  }

  /** Renders survivors with rank in [begin, begin + size), in ascending order. */
  private String output(int begin, int size) {
    final StringBuilder text = new StringBuilder();
    final int end = begin + size;
    int slot = -1;
    for (int rank = 0; rank < end; rank++) {
      slot = filter.nextClearBit(slot + 1);
      if (rank >= begin) {
        text.append(valueAt(slot)).append(" ");
      }
    }
    return text.toString();
  }

  /** Maps a slot back to its value: 0 -> 1, 1 -> 2, k -> 2k - 1. */
  private static int valueAt(int slot) {
    switch (slot) {
      case 0:
        return 1;
      case 1:
        return 2;
      default:
        return (slot << 1) - 1;
    }
  }
}

/**
 * Minimal fixed-capacity bit set backed by a {@code long[]}. Bits start clear
 * and can only be set. There is no bounds checking beyond what array access
 * performs.
 */
class FastBitSet {
  /** All 64 bits set; shifted left to mask off the low bits of a partial first word. */
  private static final long WORD_MASK = 0xffffffffffffffffL;

  /** Backing words: bit {@code n} lives in {@code bits[n >> 6]} at position {@code n & 63}. */
  protected long[] bits;
  /** Number of words in use; equals {@code bits.length} after construction. */
  protected int wlen;

  /**
   * Constructs a FastBitSet with room for at least {@code numBits} bits, all
   * clear. Capacity is rounded up to a multiple of 64.
   * 
   * @param numBits minimum number of bits to hold
   */
  public FastBitSet(long numBits) {
    bits = new long[bits2words(numBits)];
    wlen = bits.length;
  }

  /**
   * Returns the number of 64-bit words needed to hold {@code numBits} bits
   * (0 for {@code numBits == 0}).
   */
  public static int bits2words(long numBits) {
    return (int) (((numBits - 1) >>> 6) + 1);
  }
  
  /**
   * Sets the bit at {@code index}. The index must lie in
   * {@code [0, 64 * wlen)}; otherwise an {@link ArrayIndexOutOfBoundsException}
   * is thrown.
   */
  public void fastSet(int index) {
    // Shifting a long uses only the low 6 bits of the distance, i.e. index % 64.
    bits[index >> 6] |= 1L << index;
  }

  /** Returns the number of set bits. */
  public int cardinality() {
    int sum = 0;
    for (int i = 0; i < wlen; i++)
      sum += Long.bitCount(bits[i]);
    return sum;
  }

  /**
   * Returns the index of the first clear bit at or after {@code index}.
   *
   * <p>Positions past the backing array count as clear. If {@code index} is
   * already beyond the array, it is returned unchanged. If every bit from
   * {@code index} to the end of the array is set, {@code 64 * wlen} is
   * returned. The result is never -1. Padding bits above the capacity
   * requested at construction are clear and can be returned.
   *
   * @param index starting position; must be non-negative
   */
  public int nextClearBit(int index) {
    int i = index >> 6;
    if (i >= wlen)
      return index;

    // Invert so clear bits become set, then drop positions below index within this word.
    long word = ~bits[i] & (WORD_MASK << index);
    while (true) {
      if (word != 0)
        return (i << 6) + Long.numberOfTrailingZeros(word);
      if (++i == wlen)
        return wlen << 6;
      word = ~bits[i];
    }
  }
}