ChunMinChang
8/19/2017 - 3:20 AM

Exponentiation by squaring

Exponentiation by squaring #math #recursion

#include "exp.h"
#include <cassert>  // assert
#include <chrono>   // std::chrono clocks and durations
#include <cstdint>  // uint64_t
#include <ctime>    // std::clock
#include <iostream> // std::cout, std::cerr, std::cin, std::endl
#include <utility>  // std::forward

// Elapsed time of one measured call, both fields in milliseconds.
struct Time
{
  double cpu;  // processor time consumed, as reported by std::clock()
  double wall; // real (wall-clock) time elapsed
};

// The value one measured call returned, paired with how long it took.
struct Result
{
  uint64_t value;
  Time time;
};

// Calls aFunction(aArgs...) once and returns its result together with the
// processor time (std::clock) and the wall-clock time it took, both in ms.
template <typename Function, typename... Args>
Result calculate(Function aFunction, Args&&... aArgs)
{
  // steady_clock is monotonic; high_resolution_clock may be an alias of the
  // adjustable system_clock, which can jump and corrupt a measured interval.
  const auto wallStart = std::chrono::steady_clock::now();
  const std::clock_t cpuStart = std::clock();

  const uint64_t r = aFunction(std::forward<Args>(aArgs)...);

  const std::clock_t cpuEnd = std::clock();
  const auto wallEnd = std::chrono::steady_clock::now();

  const double cpuMs = 1000.0 * (cpuEnd - cpuStart) / CLOCKS_PER_SEC;
  const double wallMs =
    std::chrono::duration<double, std::milli>(wallEnd - wallStart).count();
  return Result { r, Time { cpuMs, wallMs } };
}

int main()
{
  unsigned int k = 0, n = 0;
  std::cout << "Enter k, n to get k^n: ";
  std::cin >> k >> n;
  Result p1 = calculate(pow1, k, n);
  Result p2 = calculate(pow2, k, n);
  Result p3 = calculate(pow3, k, n);
  Result p4 = calculate(pow4, k, n);
  Result p5 = calculate(pow5, k, n);
  Result p6 = calculate(pow6, k, n);
  Result p7 = calculate(pow7, k, n);
  Result p8 = calculate(pow8, k, n);
  Result p9 = calculate(pow9, k, n);
  assert(p1.value == p2.value &&
         p2.value == p3.value &&
         p3.value == p4.value &&
         p4.value == p5.value &&
         p5.value == p6.value &&
         p6.value == p7.value &&
         p7.value == p8.value &&
         p8.value == p9.value);
  std::cout << k << "^" << n << " = " << p1.value << std::endl;

  std::cout << "\nTime\n-----" << std::endl;
  std::cout << "1 \tcpu: " << p1.time.cpu << "ms, wall: " << p1.time.wall << " ms" << std::endl;
  std::cout << "2 \tcpu: " << p2.time.cpu << "ms, wall: " << p2.time.wall << " ms" << std::endl;
  std::cout << "3 \tcpu: " << p3.time.cpu << "ms, wall: " << p3.time.wall << " ms" << std::endl;
  std::cout << "4 \tcpu: " << p4.time.cpu << "ms, wall: " << p4.time.wall << " ms" << std::endl;
  std::cout << "5 \tcpu: " << p5.time.cpu << "ms, wall: " << p5.time.wall << " ms" << std::endl;
  std::cout << "6 \tcpu: " << p6.time.cpu << "ms, wall: " << p6.time.wall << " ms" << std::endl;
  std::cout << "7 \tcpu: " << p7.time.cpu << "ms, wall: " << p7.time.wall << " ms" << std::endl;
  std::cout << "8 \tcpu: " << p8.time.cpu << "ms, wall: " << p8.time.wall << " ms" << std::endl;
  std::cout << "9 \tcpu: " << p9.time.cpu << "ms, wall: " << p9.time.wall << " ms" << std::endl;

  return 0;
}
# Compiler and flags, used both for compiling each source and for linking.
CC=g++
CFLAGS=-Wall --std=c++11

SOURCES=test.cpp exp.cpp
OBJECTS=$(SOURCES:.cpp=.o)

# Name of the executable program
EXECUTABLE=run_test

all: $(EXECUTABLE)

# Link all object files into the program: $^ expands to the prerequisites
# (the object files) and $@ to the target, $(EXECUTABLE).
$(EXECUTABLE): $(OBJECTS)
	$(CC) $(CFLAGS) $^ -o $@

# ".cpp.o" is a suffix rule telling make how to turn file.cpp into file.o
# for an arbitrary file.
#
# $< is an automatic variable referencing the source file (file.cpp in the
# case of the suffix rule), and $@ references the target file (file.o).
#
# Use -c to generate the .o file without linking.
.cpp.o:
	$(CC) -c $(CFLAGS) $< -o $@

# Both test.cpp and exp.cpp include exp.h, so every object must be rebuilt
# when the header changes; the suffix rule above still supplies the recipe.
$(OBJECTS): exp.h

# `all` and `clean` name actions, not files; -f keeps `clean` from failing
# when the build products do not exist yet.
.PHONY: all clean
clean:
	rm -f *.o $(EXECUTABLE)
// Exponentiation by squaring
#ifndef EXP_BY_SQUARING
#define EXP_BY_SQUARING

#include <cstdint>  // uint64_t

// pow1 ... pow9 are alternative implementations of k^n (base k, exponent n)
// returned as uint64_t; test.cpp times each one and asserts they all agree.
// NOTE(review): results are only meaningful while k^n fits in uint64_t, and
// variants whose intermediate `k * k` is done in unsigned int may wrap
// earlier than that — verify each definition in exp.cpp.

// Recursive; evaluates pow1(k, n / 2) twice per level.
uint64_t pow1(unsigned int k, unsigned int n);
// Recursive on the squared base: (k^2)^(n/2), times k when n is odd.
uint64_t pow2(unsigned int k, unsigned int n);
// Like pow2, but applies the odd factor k after the recursive call returns.
uint64_t pow3(unsigned int k, unsigned int n);
// pow3 written with shift/mask instead of division/modulo.
uint64_t pow4(unsigned int k, unsigned int n);
// Iterative squaring; the loop stops at n == 1 and multiplies k in at the end.
uint64_t pow5(unsigned int k, unsigned int n);
// Iterative squaring; the loop runs until n == 0, so the accumulator is k^n.
uint64_t pow6(unsigned int k, unsigned int n);
// Iterative right-to-left binary exponentiation using shift/mask.
uint64_t pow7(unsigned int k, unsigned int n);
// Left-to-right over n's bits, replaying n's right-shifts from a std::stack.
uint64_t pow8(unsigned int k, unsigned int n);
// Left-to-right over n's bits, walking a one-bit mask down from the top bit.
uint64_t pow9(unsigned int k, unsigned int n);

#endif // EXP_BY_SQUARING
#include "exp.h"
#include <stack>

// Calculate the exponentiation by squaring:
//   k ^ n = (k ^ (n / 2)) ^ 2           if n is even
//        or k * (k ^ ((n - 1) / 2))^ 2  if n is odd
//   or
//   k ^ n = (k ^ 2) ^ (n / 2)           if n is even
//        or k * (k ^ 2) ^ ((n - 1) / 2) if n is odd
//
// https://en.wikipedia.org/wiki/Exponentiation_by_squaring
// k^n via the first identity above: square k^(n/2), and multiply by k once
// more when n is odd. All multiplication happens in uint64_t.
//
// The half power is evaluated by two separate recursive calls, so the work
// is T(n) = 2T(n/2) + O(1), i.e. O(n) multiplications rather than O(log n).
// NOTE(review): presumably kept this way as the slow baseline in test.cpp's
// timings; computing the half once would make it logarithmic.
uint64_t pow1(unsigned int k, unsigned int n)
{
  if (n == 0) {
    return 1; // k^0
  }

  const unsigned int half = n / 2; // equals (n - 1) / 2 when n is odd
  const uint64_t squared = pow1(k, half) * pow1(k, half);
  return (n & 1) ? k * squared : squared;
}

// Recursive exponentiation by squaring on the base:
//   k^n = (k^2)^(n/2)           if n is even
//       = k * (k^2)^((n-1)/2)   if n is odd
// The base is uint64_t so that `k * k` is computed in 64-bit arithmetic.
// With an unsigned int base, the square wrapped at 2^32 even when the final
// result fits in uint64_t (e.g. 2^33 came out as 0).
static uint64_t pow2Wide(uint64_t k, unsigned int n)
{
  if (!n) {
    return 1;
  }

  if (n % 2) {
    return k * pow2Wide(k * k, (n - 1) / 2);
  } else {
    return pow2Wide(k * k, n / 2);
  }
}

// Returns k^n (modulo 2^64) by recursive squaring of the base.
uint64_t pow2(unsigned int k, unsigned int n)
{
  return pow2Wide(k, n);
}

// Same recursion as pow2, but the odd factor k is applied after the recursive
// call returns. The base is uint64_t so `k * k` does not wrap at 2^32.
// (An equivalent formulation keeps k fixed and squares the result instead:
//   uint64_t x = pow3Wide(k, n / 2); x *= x;)
static uint64_t pow3Wide(uint64_t k, unsigned int n)
{
  if (!n) {
    return 1;
  }

  uint64_t x = pow3Wide(k * k, n / 2);
  // so x = (k^2)^(n/2),      if n is even
  //     or (k^2)^((n-1)/2),  if n is odd

  if (n % 2) {
    x = x * k; // so x = k * (k^2)^((n-1)/2)
  }
  return x;
}

// Returns k^n (modulo 2^64) by recursive squaring of the base.
uint64_t pow3(unsigned int k, unsigned int n)
{
  return pow3Wide(k, n);
}

// pow3 with shift/mask in place of division/modulo. The base is uint64_t so
// `k * k` is computed in 64 bits instead of wrapping at 2^32.
static uint64_t pow4Wide(uint64_t k, unsigned int n)
{
  if (!n) {
    return 1;
  }

  uint64_t x = pow4Wide(k * k, /*n / 2*/ n >> 1);
  return (/*n % 2*/ n & 1) ? x * k : x;
}

// Returns k^n (modulo 2^64) by recursive squaring of the base.
uint64_t pow4(unsigned int k, unsigned int n)
{
  return pow4Wide(k, n);
}

// Iterative squaring: `base` is repeatedly squared while n is halved, and
// each odd step moves one factor of `base` into `r`. The loop stops at
// n == 1, leaving the final factor to be multiplied in on return.
uint64_t pow5(unsigned int k, unsigned int n)
{
  if (!n) {
    return 1;
  }

  // Widened copy of k: squaring an unsigned int wraps at 2^32 even when the
  // true k^n still fits in uint64_t.
  uint64_t base = k;
  uint64_t r = 1; // The remaining part for the squaring.
  while (n > 1) {
    if (/*n % 2*/ n & 1) {
      r *= base;
      base *= base;
      n = (n - 1) / 2;
    } else {
      base *= base;
      n = n / 2;
    }
  }

  return r * base;
}

// Iterative squaring like pow5, but the loop also runs for n == 1.
uint64_t pow6(unsigned int k, unsigned int n)
{
  // `r` is the remaining part for the squaring, as in pow5. pow5 finishes
  // with `r * k` once n reaches 1; letting the loop run through n == 1 folds
  // that last factor into `r`, so `r` itself is the answer. The cost is one
  // extra, unused squaring of the base in that final pass.
  //
  // Widened copy of k: squaring an unsigned int wraps at 2^32 even when the
  // true k^n still fits in uint64_t.
  uint64_t base = k;
  uint64_t r = 1;

  while (n) {
    if (/*n % 2*/ n & 1) {
      r *= base;
      base *= base;
      n = (n - 1) / 2;
    } else {
      base *= base;
      n = n / 2;
    }
  }

  return r;
}

// Right-to-left binary exponentiation: scan n's bits from the lowest,
// multiplying `r` by the current power k^(2^i) wherever bit i is set.
uint64_t pow7(unsigned int k, unsigned int n)
{
  // Widened copy of k: squaring an unsigned int wraps at 2^32 even when the
  // true k^n still fits in uint64_t.
  uint64_t base = k;
  uint64_t r = 1;
  while (n) {
    if (/*n % 2*/ n & 1) {
      r *= base;
    }
    base *= base;
    /*n /= 2*/ n >>= 1;
  }

  return r;
}

// uint64_t pow8(unsigned int k, unsigned int n)
// {
//   std::stack<unsigned int> s;
//   while(n) {
//     s.push(n);
//     n >>= 1;
//   }

//   uint64_t a = 1; // a = k^0 = 1
//   while (!s.empty()) {
//     unsigned int x = s.top(); s.pop();
//     // Let y = floor(x/2), y = x/2 if x is even, y = (x-1)/2 if x is odd.
//     // then a = k^y now.
//     if (x % 2) {      // x is odd:
//       a = k * a * a;  //   a = k^x = k^(2y+1) = k * (k^y)^2
//     } else {          // x is even:
//       a = a * a;      //   a = k^x = k^(2y) = (k^y)^2
//     }
//   }

//   return a;
// }

// Left-to-right binary exponentiation. The successive right-shifts of n
// (n, n >> 1, n >> 2, ..., down to 1) are pushed onto a stack, then replayed
// from the shortest (1) back to n itself, building k^m for each of them.
uint64_t pow8(unsigned int k, unsigned int n)
{
  std::stack<unsigned int> prefixes;
  for (unsigned int m = n; m != 0; m >>= 1) {
    prefixes.push(m);
  }

  uint64_t acc = 1; // k^0
  for (; !prefixes.empty(); prefixes.pop()) {
    const unsigned int prefix = prefixes.top();
    // On entry acc == k^(prefix >> 1).
    acc *= acc;      // k^(2 * (prefix >> 1)), already k^prefix if prefix is even
    if (prefix & 1) {
      acc *= k;      // odd prefix: k^(2 * (prefix >> 1) + 1) == k^prefix
    }
  }

  return acc;
}

// uint64_t pow9(unsigned int k, unsigned int n)
// {
//   // The position of the highest bit of n.
//   // So we need to loop `h` times to get the answer.
//   // Example: n = (Dec)50 = (Bin)00110010, then h = 6.
//   //                               ^ 6th bit from right side
//   unsigned int h = 0;
//   for (unsigned int i = n ; i ; ++h, i >>= 1);

//   uint64_t a = 1; // a = k^0 = 1
//   // There is only one `1` in the bits of `mask`. The `1`'s position is same as
//   // the highest bit of n(mask = 2^(h-1) at first), and it will be shifted right
//   // iteratively to do `AND` operation with `n` to check `n_j` is odd or even,
//   // where n_j is defined below.
//   for (unsigned int i = 1, mask = 1 << (h-1) ; i <= h ; ++i, mask >>= 1) {
//     // Let j = h-i (looping from i = 1 to i = h), n_j = floor(n / 2^j) = n >> j
//     // (n_j = n when j = 0), x = floor(n_j / 2), then a = k^x now.
//     if (n & mask) {   // n_j is odd: x = (n_j - 1) / 2 => n_j = 2x + 1
//       a = k * a * a;  //   a = k^(n_j) = k^(2x+1) = k * (k^x)^2
//     } else {          // n_j is even: x = n_j / 2 => n_j = 2x
//       a = a * a;      //   a = k^(n_j) = k^(2x) = (k^x)^2
//     }
//   }

//   return a;
// }

// uint64_t pow9(unsigned int k, unsigned int n)
// {
//   // The position of the highest bit of n.
//   // So we need to loop `h` times to get the answer.
//   // Example: n = (Dec)50 = (Bin)00110010, then h = 6.
//   //                               ^ 6th bit from right side
//   unsigned int h = 0;
//   for (unsigned int i = n ; i ; ++h, i >>= 1);

//   uint64_t a = 1; // a = k^0 = 1
//   // There is only one `1` in the bits of `mask`. The `1`'s position is same as
//   // the highest bit of n(mask = 2^(h-1) at first), and it will be shifted right
//   // iteratively to do `AND` operation with `n` to check `n_j` is odd or even,
//   // where n_j is defined below.
//   for (unsigned int i = 1, mask = 1 << (h-1) ; i <= h ; ++i, mask >>= 1) {
//     // Let j = h-i (looping from i = 1 to i = h), n_j = floor(n / 2^j) = n >> j
//     // (n_j = n when j = 0), x = floor(n_j / 2), then a = k^x now.
//     a *= a; // a = (k^x)^2 = k^(2x)
//                     // n_j is even: x = n_j / 2 => n_j = 2x
//                     //   a = k^(n_j) = k^(2x)
//     if (n & mask) { // n_j is odd: x = (n_j - 1) / 2 => n_j = 2x + 1
//       a *= k;       //   a = k^(n_j) = k^(2x+1) = k * k^(2x)
//     }
//   }

//   return a;
// }

// Left-to-right binary exponentiation driven by a one-bit mask that walks
// from n's highest set bit down to bit 0.
uint64_t pow9(unsigned int k, unsigned int n)
{
  // k^0 = 1. Returning early also keeps the shift below well-defined: for
  // n == 0 the bit count h would be 0, and `1 << (h - 1)` would shift by
  // UINT_MAX, which is undefined behavior.
  if (!n) {
    return 1;
  }

  // h is the position of the highest set bit of n, so the loop below runs
  // h times. Example: n = (Dec)50 = (Bin)00110010, then h = 6.
  //                                     ^ 6th bit from right side
  unsigned int h = 0;
  for (unsigned int i = n; i; i >>= 1) {
    ++h;
  }

  uint64_t a = 1; // a = k^0 = 1
  // `mask` has a single set bit. It starts at n's highest bit (2^(h-1)) and
  // moves one position right per iteration. While mask == 2^j, let
  // n_j = n >> j (so n_0 = n) and x = n_j >> 1; on entry a == k^x.
  for (unsigned int mask = 1u << (h - 1); mask; mask >>= 1) {
    a *= a;           // a = (k^x)^2 = k^(2x), which is k^(n_j) if n_j is even
    if (n & mask) {   // n_j is odd: n_j = 2x + 1
      a *= k;         //   a = k^(2x+1) = k^(n_j)
    }
  }

  return a;
}