# sdm0718-batch created by HiKat1 - https://repl.it/Jfa6/15
#!/usr/bin/env python
# coding:utf-8
#
# use Python3 !!!
#
import math
# import numpy as np
# Step size applied to the batch-averaged gradient in update_weight().
LEARNING_RATE = 0.1
# Original (misspelled) spelling, kept as an alias because other code in this
# file still refers to it.
LEAERNING_RATE = LEARNING_RATE

# Seed values from which the initial weights below are derived.
# NOTE(review): what b1..b3 represent is not visible in this file -- confirm.
b1 = 8
b2 = 2
b3 = 0

# Initial weights as used by forward_backward():
#   w11 / w13 : inputs x1 / x2 -> first hidden unit
#   w12 / w14 : inputs x1 / x2 -> second hidden unit
#   w21 / w22 : first / second hidden unit -> output unit
w11 = (b2 + 1.0) / 10.0
w12 = -1.0 * (b3 + 1.0) / 10.0
w13 = (b1 + 1.0) / 10.0
w14 = (b3 + 1.0) / 10.0
w21 = -1.0 * (b1 + 1.0) / 10.0
w22 = -1.0 * (b2 + 1.0) / 10.0

# Running sums of the per-sample gradients; forward_backward() adds to them
# and update_weight() reads them.
grad_w11_accum = 0
grad_w12_accum = 0
grad_w13_accum = 0
grad_w14_accum = 0
grad_w21_accum = 0
grad_w22_accum = 0
# Forward and backward pass for one training sample
def forward_backward(x1, x2, y):
    """Run one sample through the network and accumulate its gradients.

    The network has two inputs, two tanh hidden units and one tanh output
    unit, using the module-level weights w11..w14, w21 and w22.

    Args:
        x1: First input value.
        x2: Second input value.
        y: Target output for this sample.

    Returns:
        The network output (tanh of the output unit's pre-activation).

    Side effects:
        Adds this sample's gradient for every weight to the matching
        module-level ``grad_*_accum`` variable and prints the output.
    """
    global grad_w11_accum, grad_w12_accum, grad_w13_accum
    global grad_w14_accum, grad_w21_accum, grad_w22_accum

    # Forward pass: each activation is evaluated once and reused below.
    hidden_a = math.tanh(x1 * w11 + x2 * w13)
    hidden_b = math.tanh(x1 * w12 + x2 * w14)
    y_dash = math.tanh(hidden_a * w21 + hidden_b * w22)

    # Backward pass: (1 - tanh(a)^2) is the derivative of tanh at a.
    # delta_out is the derivative of 0.5 * (y - y_dash)^2 with respect to the
    # output unit's pre-activation.
    delta_out = -1.0 * (y - y_dash) * (1.0 - y_dash ** 2)
    delta_a = delta_out * w21 * (1.0 - hidden_a ** 2)
    delta_b = delta_out * w22 * (1.0 - hidden_b ** 2)

    # Gradient of each weight = delta of the unit it feeds * its input;
    # add it to the running batch totals.
    grad_w11_accum += delta_a * x1
    grad_w12_accum += delta_b * x1
    grad_w13_accum += delta_a * x2
    grad_w14_accum += delta_b * x2
    grad_w21_accum += delta_out * hidden_a
    grad_w22_accum += delta_out * hidden_b

    print(y_dash)
    return y_dash
def update_weight():
    """Return the weights after one batch gradient-descent step.

    Each weight is moved against its accumulated gradient, scaled by the
    learning rate and by 1/4 (presumably the four samples the script feeds
    through forward_backward() -- the divisor is hard-coded here).

    The module-level weights are not modified.

    Returns:
        A list ``[w11, w12, w13, w14, w21, w22]`` of the updated values.
    """
    weight_grad_pairs = (
        (w11, grad_w11_accum),
        (w12, grad_w12_accum),
        (w13, grad_w13_accum),
        (w14, grad_w14_accum),
        (w21, grad_w21_accum),
        (w22, grad_w22_accum),
    )
    return [
        weight + LEAERNING_RATE * -1.0 * grad_sum * (1.0 / 4.0)
        for weight, grad_sum in weight_grad_pairs
    ]
# Q1: show the initial weights.
print("Answer Q1 ================================")
print('w11 = {0} ,w13 = {1}, w21 = {2}'.format(w11, w13, w21))
print('w12 = {0} ,w14 = {1}, w22 = {2}'.format(w12, w14, w22))

# Feed the four (x1, x2, y) training samples through the network; each call
# prints its output and adds to the gradient accumulators.
for _sample in (
    (1.0, 1.0, 1.0),
    (0, 1.0, -1.0),
    (0, 0, 1.0),
    (1.0, 0, -1.0),
):
    forward_backward(*_sample)

# One batch update; temp is ordered [w11, w12, w13, w14, w21, w22], so the
# even indices are w11/w13/w21 and the odd indices are w12/w14/w22.
print("batch SGD !!!!!!!!!=========")
temp = update_weight()
print('w11 = {0} ,w13 = {1}, w21 = {2}'.format(*temp[0::2]))
print('w12 = {0} ,w14 = {1}, w22 = {2}'.format(*temp[1::2]))