HiKat
7/24/2017 - 12:04 PM

sdm0718-batch created by HiKat1 - https://repl.it/Jfa6/15

#!/usr/bin/env python
# coding:utf-8
#
# Requires Python 3.
#
import math
# import numpy as np

# Step size used by the batch gradient-descent update.
LEARNING_RATE = 0.1
# Alias that preserves the original (misspelled) name, which other code in
# this file may still reference.
LEAERNING_RATE = LEARNING_RATE

# Integer parameters from which the initial weights are derived.
b1 = 8
b2 = 2
b3 = 0

# Initial weights: each is +/-(b_k + 1) / 10.
# In forward_backward, w11/w13 weight the inputs into pre-activation a2,
# w12/w14 weight the inputs into pre-activation a3, and w21/w22 weight
# tanh(a2)/tanh(a3) into the output pre-activation a1.
w11 = (b2 + 1.0) / 10.0
w12 = -(b3 + 1.0) / 10.0
w13 = (b1 + 1.0) / 10.0
w14 = (b3 + 1.0) / 10.0
w21 = -(b1 + 1.0) / 10.0
w22 = -(b2 + 1.0) / 10.0


# Per-weight gradient sums, accumulated across a batch by forward_backward
# and consumed by update_weight.
grad_w11_accum = 0
grad_w12_accum = 0
grad_w13_accum = 0
grad_w14_accum = 0
grad_w21_accum = 0
grad_w22_accum = 0



# Forward and backward pass for a single training sample.
def forward_backward(x1, x2, y):
    """Run one forward and one backward pass and accumulate the gradients.

    The network has two inputs, two tanh hidden units and one tanh output
    unit, with no bias terms; the weights are the module-level w11..w22.
    The gradient of E = (y - output)**2 / 2 with respect to each weight is
    added to the matching module-level ``grad_w*_accum`` accumulator.  The
    weights themselves are not modified.

    Args:
        x1: First input value.
        x2: Second input value.
        y: Target output value.

    Returns:
        The network output ``tanh(a1)``; it is also printed.
    """
    global grad_w11_accum, grad_w12_accum, grad_w13_accum
    global grad_w14_accum, grad_w21_accum, grad_w22_accum

    # Forward pass: two hidden pre-activations, their tanh outputs, then
    # the output pre-activation and its tanh.
    a2 = x1 * w11 + x2 * w13
    a3 = x1 * w12 + x2 * w14
    hidden2 = math.tanh(a2)
    hidden3 = math.tanh(a3)
    a1 = hidden2 * w21 + hidden3 * w22
    output = math.tanh(a1)

    # Backward pass, using tanh'(a) = 1 - tanh(a)**2.
    # err_out is dE/da1; err2 and err3 are dE/da2 and dE/da3.
    err_out = -(y - output) * (1.0 - output ** 2)
    err2 = err_out * w21 * (1.0 - hidden2 ** 2)
    err3 = err_out * w22 * (1.0 - hidden3 ** 2)

    # Each weight's gradient is the downstream error times its input.
    grad_w11_accum += err2 * x1
    grad_w12_accum += err3 * x1
    grad_w13_accum += err2 * x2
    grad_w14_accum += err3 * x2
    grad_w21_accum += err_out * hidden2
    grad_w22_accum += err_out * hidden3

    print(output)
    return output


def update_weight(batch_size=4):
    """Return the weights after one batch gradient-descent step.

    Each weight w is moved against its accumulated gradient, averaged over
    the batch: ``w - LEAERNING_RATE * grad_accum / batch_size``.  Neither
    the module-level weights nor the ``grad_w*_accum`` accumulators are
    modified; the caller receives the updated values instead.

    Args:
        batch_size: Number of samples whose gradients were accumulated.
            Defaults to 4, the factor that was previously hard-coded.

    Returns:
        A list ``[w11, w12, w13, w14, w21, w22]`` of updated weights.

    Raises:
        ValueError: If ``batch_size`` is not positive.
    """
    if batch_size <= 0:
        raise ValueError(
            "batch_size must be positive, got {0}".format(batch_size))

    # (current weight, accumulated gradient) in the documented return order.
    pairs = (
        (w11, grad_w11_accum),
        (w12, grad_w12_accum),
        (w13, grad_w13_accum),
        (w14, grad_w14_accum),
        (w21, grad_w21_accum),
        (w22, grad_w22_accum),
    )
    return [w - LEAERNING_RATE * grad / batch_size for w, grad in pairs]

print("Answer Q1 ================================")
print('w11 = {0} ,w13 = {1}, w21 = {2}'.format(w11, w13, w21))
print('w12 = {0} ,w14 = {1}, w22 = {2}'.format(w12, w14, w22))

# Training batch as (x1, x2, target) triples, in the original call order.
TRAINING_SAMPLES = (
    (1.0, 1.0, 1.0),
    (0, 1.0, -1.0),
    (0, 0, 1.0),
    (1.0, 0, -1.0),
)
# Accumulate gradients over the whole batch; forward_backward prints each
# sample's network output.
for sample_x1, sample_x2, sample_y in TRAINING_SAMPLES:
    forward_backward(sample_x1, sample_x2, sample_y)

print("batch SGD !!!!!!!!!=========")
# One batch update; the result is [w11, w12, w13, w14, w21, w22].
temp = update_weight()
new_w11, new_w12, new_w13, new_w14, new_w21, new_w22 = temp
print('w11 = {0} ,w13 = {1}, w21 = {2}'.format(new_w11, new_w13, new_w21))
print('w12 = {0} ,w14 = {1}, w22 = {2}'.format(new_w12, new_w14, new_w22))