ChunMinChang
8/19/2017 - 3:20 AM

## Exponentiation by squaring

Exponentiation by squaring #math #recursion

``````#include "exp.h"
#include <cassert>  // assert
#include <chrono>   // std::chrono clocks, std::chrono::duration
#include <cstdint>  // uint64_t
#include <ctime>    // std::clock
#include <iostream> // std::cout, std::cin, std::endl
#include <utility>  // std::forward

// Elapsed time of one measured call. calculate() fills both fields in
// milliseconds.
struct Time
{
  double cpu;   // Processor time, derived from std::clock().
  double wall;  // Wall-clock time, derived from a std::chrono clock.
};

// Outcome of one measured call: the value the callable returned and how long
// the call took.
struct Result
{
  uint64_t value;  // Return value of the measured callable.
  Time time;       // Elapsed CPU / wall time of the call.
};

// Invokes aFunction(aArgs...) once and reports its result together with the
// time the call took.
//
// aFunction: a callable whose return value converts to uint64_t.
// aArgs:     arguments, perfectly forwarded to aFunction.
//
// Returns a Result whose `value` is what aFunction returned and whose `time`
// holds the elapsed CPU time (via std::clock) and the elapsed wall-clock time,
// both expressed in milliseconds.
template <typename Function, typename... Args>
Result calculate(Function aFunction, Args&&... aArgs)
{
  // steady_clock is monotonic, so the measured interval cannot be distorted
  // (or even turn negative) when the system clock is adjusted while the
  // callable runs. high_resolution_clock gives no such guarantee.
  typedef std::chrono::steady_clock WallClock;

  const std::clock_t cpuStart = std::clock();
  const WallClock::time_point wallStart = WallClock::now();

  const uint64_t value = aFunction(std::forward<Args>(aArgs)...);

  const std::clock_t cpuEnd = std::clock();
  const WallClock::time_point wallEnd = WallClock::now();

  // std::clock() counts ticks; CLOCKS_PER_SEC ticks make up one second.
  const double cpuMs = 1000.0 * (cpuEnd - cpuStart) / CLOCKS_PER_SEC;
  const double wallMs =
    std::chrono::duration<double, std::milli>(wallEnd - wallStart).count();

  return Result { value, Time { cpuMs, wallMs } };
}

int main()
{
unsigned int k = 0, n = 0;
std::cout << "Enter k, n to get k^n: ";
std::cin >> k >> n;
Result p1 = calculate(pow1, k, n);
Result p2 = calculate(pow2, k, n);
Result p3 = calculate(pow3, k, n);
Result p4 = calculate(pow4, k, n);
Result p5 = calculate(pow5, k, n);
Result p6 = calculate(pow6, k, n);
Result p7 = calculate(pow7, k, n);
Result p8 = calculate(pow8, k, n);
Result p9 = calculate(pow9, k, n);
assert(p1.value == p2.value &&
p2.value == p3.value &&
p3.value == p4.value &&
p4.value == p5.value &&
p5.value == p6.value &&
p6.value == p7.value &&
p7.value == p8.value &&
p8.value == p9.value);
std::cout << k << "^" << n << " = " << p1.value << std::endl;

std::cout << "\nTime\n-----" << std::endl;
std::cout << "1 \tcpu: " << p1.time.cpu << "ms, wall: " << p1.time.wall << " ms" << std::endl;
std::cout << "2 \tcpu: " << p2.time.cpu << "ms, wall: " << p2.time.wall << " ms" << std::endl;
std::cout << "3 \tcpu: " << p3.time.cpu << "ms, wall: " << p3.time.wall << " ms" << std::endl;
std::cout << "4 \tcpu: " << p4.time.cpu << "ms, wall: " << p4.time.wall << " ms" << std::endl;
std::cout << "5 \tcpu: " << p5.time.cpu << "ms, wall: " << p5.time.wall << " ms" << std::endl;
std::cout << "6 \tcpu: " << p6.time.cpu << "ms, wall: " << p6.time.wall << " ms" << std::endl;
std::cout << "7 \tcpu: " << p7.time.cpu << "ms, wall: " << p7.time.wall << " ms" << std::endl;
std::cout << "8 \tcpu: " << p8.time.cpu << "ms, wall: " << p8.time.wall << " ms" << std::endl;
std::cout << "9 \tcpu: " << p9.time.cpu << "ms, wall: " << p9.time.wall << " ms" << std::endl;

return 0;
}
``````
``````CC=g++
CFLAGS=-Wall --std=c++11

SOURCES=test.cpp\
        exp.cpp
OBJECTS=$(SOURCES:.cpp=.o)

# Name of the executable program
EXECUTABLE=run_test

# "all" and "clean" name actions, not files, so they must always run even if
# a file with the same name exists.
.PHONY: all clean

all: $(EXECUTABLE)

# $@ is the same as $(EXECUTABLE) here.
$(EXECUTABLE): $(OBJECTS)
	$(CC) $(CFLAGS) $(OBJECTS) -o $@

# test.cpp and exp.cpp are expected to include exp.h, so rebuild the objects
# whenever the header changes.
$(OBJECTS): exp.h

# ".cpp.o" is a suffix rule telling make how to turn file.cpp into file.o
# for an arbitrary file.
#
# $< is an automatic variable referencing the source file,
# file.cpp in the case of the suffix rule.
#
# $@ is an automatic variable referencing the target file,
# file.o in the case of the suffix rule.
#
# Use -c to generate the .o file.
.cpp.o:
	$(CC) -c $(CFLAGS) $< -o $@

# -f keeps "make clean" from failing when there is nothing to remove.
clean:
	rm -f *.o $(EXECUTABLE)
``````
``````// Exponentiation by squaring
#ifndef EXP_BY_SQUARING
#define EXP_BY_SQUARING

#include <cstdint>  // uint64_t

// Nine variants of computing k raised to the power n by repeated squaring.
// All take the base k and the exponent n, and all return 1 when n is 0.

// Recursive: multiplies two recursive calls on the halved exponent.
uint64_t pow1(unsigned int k, unsigned int n);
// Recursive: squares the base and halves the exponent, with a separate
// return statement for odd and even n.
uint64_t pow2(unsigned int k, unsigned int n);
// Recursive: like pow2, but makes one recursive call, stores it in a local,
// and multiplies by k afterwards when n is odd.
uint64_t pow3(unsigned int k, unsigned int n);
// Recursive: pow3 written with `>> 1` and `& 1` in place of `/ 2` and `% 2`.
uint64_t pow4(unsigned int k, unsigned int n);
// Iterative: loops while n > 1, collecting odd factors in a remainder, and
// returns remainder * base.
uint64_t pow5(unsigned int k, unsigned int n);
// Iterative: like pow5, but loops until n reaches 0 and returns the
// accumulated remainder directly.
uint64_t pow6(unsigned int k, unsigned int n);
// Iterative: pow6 with the squaring and halving shared by both parities.
uint64_t pow7(unsigned int k, unsigned int n);
// Iterative: pushes n, n >> 1, n >> 2, ... onto a std::stack, then pops them
// while squaring the accumulator and multiplying by k for odd entries.
uint64_t pow8(unsigned int k, unsigned int n);
// Iterative: same as pow8, but walks the bits of n with a mask that starts
// at the highest set bit instead of using a stack.
uint64_t pow9(unsigned int k, unsigned int n);

#endif // EXP_BY_SQUARING``````
``````#include "exp.h"
#include <stack>

// Calculate the exponentiation by squaring:
//   k ^ n = (k ^ (n / 2)) ^ 2           if n is even
//        or k * (k ^ ((n - 1) / 2)) ^ 2 if n is odd
//   or
//   k ^ n = (k ^ 2) ^ (n / 2)           if n is even
//        or k * (k ^ 2) ^ ((n - 1) / 2) if n is odd
//
// https://en.wikipedia.org/wiki/Exponentiation_by_squaring
// Computes k^n (modulo 2^64) straight from the identity
//   k^n = k^(n/2) * k^(n/2)        for even n,
//   k^n = k * k^(n/2) * k^(n/2)    for odd n  (n/2 rounds down).
//
// k: base, n: exponent. Returns 1 when n is 0.
//
// The half power is evaluated by two separate recursive calls, so the number
// of multiplications grows linearly with n rather than logarithmically.
uint64_t pow1(unsigned int k, unsigned int n)
{
  if (n == 0) {
    return 1;
  }

  // For odd n, (n - 1) / 2 and n / 2 are the same value in integer division.
  const unsigned int half = n / 2;
  const uint64_t squared = pow1(k, half) * pow1(k, half);

  const bool isOdd = (n % 2) != 0;
  return isOdd ? k * squared : squared;
}

// Recursion behind pow2, with the base carried in 64 bits.
//
// The squared base is passed down the recursion, so it has to be as wide as
// the result: squaring it in `unsigned int` would wrap at 2^32 and corrupt
// results that still fit into uint64_t (e.g. 65536^2).
static uint64_t pow2Wide(uint64_t base, unsigned int n)
{
  if (!n) {
    return 1;
  }

  if (n % 2) {
    // n odd:  base^n = base * (base^2)^((n - 1) / 2)
    return base * pow2Wide(base * base, (n - 1) / 2);
  } else {
    // n even: base^n = (base^2)^(n / 2)
    return pow2Wide(base * base, n / 2);
  }
}

// Computes k^n (modulo 2^64) by squaring the base and halving the exponent:
//   k^n = (k^2)^(n/2) for even n, k * (k^2)^((n-1)/2) for odd n.
//
// k: base, n: exponent. Returns 1 when n is 0.
uint64_t pow2(unsigned int k, unsigned int n)
{
  return pow2Wide(k, n);
}

// Recursion behind pow3, with the base carried in 64 bits so that squaring
// it cannot wrap at 2^32 while the final result still fits into uint64_t.
static uint64_t pow3Wide(uint64_t base, unsigned int n)
{
  if (!n) {
    return 1;
  }

  // x = (base^2)^(n/2)      if n is even
  //  or (base^2)^((n-1)/2)  if n is odd (n / 2 rounds down).
  // An equivalent formulation would be: x = pow3Wide(base, n / 2); x *= x;
  uint64_t x = pow3Wide(base * base, n / 2);

  if (n % 2) {
    x *= base; // so x = base * (base^2)^((n-1)/2)
  }
  return x;
}

// Computes k^n (modulo 2^64) with a single recursive call per level: the
// power of the squared base is computed first and multiplied by k afterwards
// when n is odd.
//
// k: base, n: exponent. Returns 1 when n is 0.
uint64_t pow3(unsigned int k, unsigned int n)
{
  return pow3Wide(k, n);
}

// Recursion behind pow4, with the base carried in 64 bits so that squaring
// it cannot wrap at 2^32 while the final result still fits into uint64_t.
static uint64_t pow4Wide(uint64_t base, unsigned int n)
{
  if (!n) {
    return 1;
  }

  const uint64_t x = pow4Wide(base * base, /*n / 2*/ n >> 1);
  return (/*n % 2*/ n & 1) ? x * base : x;
}

// Computes k^n (modulo 2^64); the same recursion as pow3, written with a
// shift and a mask instead of division and modulo.
//
// k: base, n: exponent. Returns 1 when n is 0.
uint64_t pow4(unsigned int k, unsigned int n)
{
  return pow4Wide(k, n);
}

// Computes k^n (modulo 2^64) iteratively: the base is squared and the
// exponent halved until the exponent reaches 1; factors split off for odd
// exponents are collected in `r`.
//
// k: base, n: exponent. Returns 1 when n is 0.
uint64_t pow5(unsigned int k, unsigned int n)
{
  if (!n) {
    return 1;
  }

  // The running base must be 64 bits wide: squaring it in `unsigned int`
  // would wrap at 2^32 and corrupt results that still fit into uint64_t.
  uint64_t base = k;
  uint64_t r = 1; // The remaining part for the squaring.
  while (n > 1) {
    if (/*n % 2*/ n & 1) {
      r *= base;
      base *= base;
      n = (n - 1) / 2;
    } else {
      base *= base;
      n = n / 2;
    }
  }

  // n is 1 here, so the answer is the collected remainder times base^1.
  return r * base;
}

// Computes k^n (modulo 2^64) iteratively. Unlike pow5 the loop keeps going
// until the exponent reaches 0, so the last odd step (n == 1) folds the base
// into `r` and `r` itself is the answer. The price is one extra, unused
// squaring of the base in that last step.
//
// k: base, n: exponent. Returns 1 when n is 0.
uint64_t pow6(unsigned int k, unsigned int n)
{
  // The running base must be 64 bits wide: squaring it in `unsigned int`
  // would wrap at 2^32 and corrupt results that still fit into uint64_t.
  uint64_t base = k;
  uint64_t r = 1;

  while (n) {
    if (/*n % 2*/ n & 1) {
      r *= base;
      base *= base;
      n = (n - 1) / 2;
    } else {
      base *= base;
      n = n / 2;
    }
  }

  return r;
}

// Computes k^n (modulo 2^64) iteratively, scanning the bits of n from the
// lowest to the highest: the base is squared once per bit and multiplied into
// the result whenever the current bit is set.
//
// k: base, n: exponent. Returns 1 when n is 0.
uint64_t pow7(unsigned int k, unsigned int n)
{
  // The running base must be 64 bits wide: squaring it in `unsigned int`
  // would wrap at 2^32 and corrupt results that still fit into uint64_t.
  uint64_t base = k;
  uint64_t r = 1;
  while (n) {
    if (/*n % 2*/ n & 1) {
      r *= base;
    }
    base *= base;
    /*n /= 2*/ n >>= 1;
  }

  return r;
}

// uint64_t pow8(unsigned int k, unsigned int n)
// {
//   std::stack<unsigned int> s;
//   while(n) {
//     s.push(n);
//     n >>= 1;
//   }

//   uint64_t a = 1; // a = k^0 = 1
//   while (!s.empty()) {
//     unsigned int x = s.top(); s.pop();
//     // Let y = floor(x/2), y = x/2 if x is even, y = (x-1)/2 if x is odd.
//     // then a = k^y now.
//     if (x % 2) {      // x is odd:
//       a = k * a * a;  //   a = k^x = k^(2y+1) = k * (k^y)^2
//     } else {          // x is even:
//       a = a * a;      //   a = k^x = k^(2y) = (k^y)^2
//     }
//   }

//   return a;
// }

// Computes k^n (modulo 2^64) from the highest bit of n down to the lowest.
//
// The values n, n >> 1, n >> 2, ..., 1 are pushed onto a stack, so popping
// yields them smallest first. Each popped value x satisfies
// x = 2 * floor(x / 2) + (x % 2), and the accumulator holds k^floor(x / 2)
// before the step and k^x after it.
//
// k: base, n: exponent. Returns 1 when n is 0 (nothing is pushed).
uint64_t pow8(unsigned int k, unsigned int n)
{
  std::stack<unsigned int> exponents;
  for (unsigned int rest = n; rest != 0; rest >>= 1) {
    exponents.push(rest);
  }

  uint64_t acc = 1; // k^0
  for (; !exponents.empty(); exponents.pop()) {
    const unsigned int x = exponents.top();
    acc = acc * acc;   // k^(2 * floor(x / 2)); already k^x when x is even
    if (x & 1u) {
      acc = acc * k;   // x is odd: one more factor of k gives k^x
    }
  }

  return acc;
}

// uint64_t pow9(unsigned int k, unsigned int n)
// {
//   // The position of the highest bit of n.
//   // So we need to loop `h` times to get the answer.
//   // Example: n = (Dec)50 = (Bin)00110010, then h = 6.
//   //                               ^ 6th bit from right side
//   unsigned int h = 0;
//   for (unsigned int i = n ; i ; ++h, i >>= 1);

//   uint64_t a = 1; // a = k^0 = 1
//   // There is only one `1` in the bits of `mask`. The `1`'s position is same as
//   // the highest bit of n(mask = 2^(h-1) at first), and it will be shifted right
//   // iteratively to do `AND` operation with `n` to check `n_j` is odd or even,
//   // where n_j is defined below.
//   for (unsigned int i = 1, mask = 1 << (h-1) ; i <= h ; ++i, mask >>= 1) {
//     // Let j = h-i (looping from i = 1 to i = h), n_j = floor(n / 2^j) = n >> j
//     // (n_j = n when j = 0), x = floor(n_j / 2), then a = k^x now.
//     if (n & mask) {   // n_j is odd: x = (n_j - 1) / 2 => n_j = 2x + 1
//       a = k * a * a;  //   a = k^(n_j) = k^(2x+1) = k * (k^x)^2
//     } else {          // n_j is even: x = n_j / 2 => n_j = 2x
//       a = a * a;      //   a = k^(n_j) = k^(2x) = (k^x)^2
//     }
//   }

//   return a;
// }

// uint64_t pow9(unsigned int k, unsigned int n)
// {
//   // The position of the highest bit of n.
//   // So we need to loop `h` times to get the answer.
//   // Example: n = (Dec)50 = (Bin)00110010, then h = 6.
//   //                               ^ 6th bit from right side
//   unsigned int h = 0;
//   for (unsigned int i = n ; i ; ++h, i >>= 1);

//   uint64_t a = 1; // a = k^0 = 1
//   // There is only one `1` in the bits of `mask`. The `1`'s position is same as
//   // the highest bit of n(mask = 2^(h-1) at first), and it will be shifted right
//   // iteratively to do `AND` operation with `n` to check `n_j` is odd or even,
//   // where n_j is defined below.
//   for (unsigned int i = 1, mask = 1 << (h-1) ; i <= h ; ++i, mask >>= 1) {
//     // Let j = h-i (looping from i = 1 to i = h), n_j = floor(n / 2^j) = n >> j
//     // (n_j = n when j = 0), x = floor(n_j / 2), then a = k^x now.
//     a *= a; // a = (k^x)^2 = k^(2x)
//                     // n_j is even: x = n_j / 2 => n_j = 2x
//                     //   a = k^(n_j) = k^(2x)
//     if (n & mask) { // n_j is odd: x = (n_j - 1) / 2 => n_j = 2x + 1
//       a *= k;       //   a = k^(n_j) = k^(2x+1) = k * k^(2x)
//     }
//   }

//   return a;
// }

// Computes k^n (modulo 2^64) from the highest set bit of n down to the
// lowest, using a one-bit mask instead of the stack used by pow8.
//
// k: base, n: exponent. Returns 1 when n is 0.
uint64_t pow9(unsigned int k, unsigned int n)
{
  // With n == 0 there is no highest bit: h would stay 0 and the initial mask
  // below would be shifted by (unsigned)-1, which is undefined behaviour.
  if (!n) {
    return 1;
  }

  // The position of the highest bit of n.
  // So we need to loop `h` times to get the answer.
  // Example: n = (Dec)50 = (Bin)00110010, then h = 6.
  //                               ^ 6th bit from right side
  unsigned int h = 0;
  for (unsigned int i = n ; i ; ++h, i >>= 1);

  uint64_t a = 1; // a = k^0 = 1
  // `mask` has exactly one bit set. It starts at the highest set bit of n
  // (mask = 2^(h-1)) and is shifted right once per iteration, so the loop
  // runs h times. The literal is unsigned so that h == 32 does not shift a
  // 1 into the sign bit of a signed int.
  for (unsigned int mask = 1u << (h - 1) ; mask ; mask >>= 1) {
    // Let n_j be n with all bits below `mask` dropped, and
    // x = floor(n_j / 2); then a = k^x at this point.
    a *= a; // a = (k^x)^2 = k^(2x), which is k^(n_j) when n_j is even
    if (n & mask) { // n_j is odd: n_j = 2x + 1
      a *= k;       //   a = k^(2x+1) = k * k^(2x)
    }
  }

  return a;
}