# How `model.trainable = False` works in Keras (GAN model)
# coding: utf8
## based on this article: http://qiita.com/mokemokechicken/items/937a82cfdc31e9a6ca12
import numpy as np
from keras.models import Sequential
from keras.engine.topology import Input, Container
from keras.engine.training import Model
from keras.layers.core import Dense
def all_weights(m):
    """Return every weight array of *m* flattened into a plain list.

    ``m`` is any object exposing ``get_weights()`` that yields array-like
    objects with a ``reshape`` method (the Keras models in this script).
    The result has one inner list per weight array, in the order
    ``get_weights()`` returns them, so two snapshots can be compared by eye
    when printed.
    """
    flattened = []
    for weight in m.get_weights():
        # reshape(-1) collapses the array to 1-D; list() keeps the original
        # scalar element type so the printed values are unchanged.
        flattened.append(list(weight.reshape(-1)))
    return flattened
def random_fit(m):
    """Run one silent ``m.fit`` call on a freshly drawn random batch.

    Inputs and targets are both drawn from ``np.random.random`` with shape
    (5, 2); inputs are drawn first, then targets, so the sequence is
    reproducible under a fixed NumPy seed. Nothing is returned; the call is
    made only for its effect on the weights of *m*.
    """
    batch_shape = (5, 2)
    # Drawing with a shape tuple consumes the global NumPy RNG exactly like
    # random(10).reshape((5, 2)) does: ten uniform values in row-major order.
    x1 = np.random.random(batch_shape)
    y1 = np.random.random(batch_shape)
    m.fit(x1, y1, verbose=False)
# Seed NumPy's global RNG before building the models and drawing the random
# training batches, so repeated runs use the same random stream.
np.random.seed(100)

# Discriminator model: 2 inputs -> Dense(1) -> Dense(2).
x = in_x = Input((2, ))
x = Dense(1)(x)
x = Dense(2)(x)
model_D = Model(in_x, x)

# Compile D while it is still trainable (the flag has not been touched yet),
# so fitting model_D directly updates its weights.
model_D.compile(optimizer="sgd", loss="mse")

# Generator model: same tiny architecture as D, but with its own layers.
x = in_x = Input((2, ))
x = Dense(1)(x)
x = Dense(2)(x)
model_G = Model(in_x, x)

# Adversarial model: G followed by D, reusing the very same model objects.
model_A = Sequential()
model_A.add(model_G)
model_A.add(model_D)

# Compile A only AFTER marking D as non-trainable. D was already compiled
# above, before the flag was changed; the demo below shows what each
# compiled model then updates when fitted.
model_D.trainable = False  # set D in A "trainable=False"
model_A.compile(optimizer="sgd", loss="mse")


def _print_weights(title, leading_blank=True):
    """Print *title* and the current flattened weights of G, D and A.

    When *leading_blank* is true an empty line is printed first to separate
    this snapshot from the previous one.
    """
    if leading_blank:
        # A bare `print` only emits a newline as a Python 2 statement; on
        # Python 3 it is a no-op expression. print("") yields an empty line
        # on both.
        print("")
    print(title)
    print("G: %s" % all_weights(model_G))
    print("D: %s" % all_weights(model_D))
    print("A : %s" % all_weights(model_A))


# Watch which weights are updated by model.fit
_print_weights("Initial Weights", leading_blank=False)

random_fit(model_D)
_print_weights("after training D --- D and D in A changed")

random_fit(model_A)
_print_weights("after training A --- D didn't changed!")

random_fit(model_D)
_print_weights("after training D")

random_fit(model_A)
_print_weights("after training A")
# Initial Weights
# G: [[-0.27850878, -0.52411258], [0.0], [0.94569027, 0.83747566], [0.0, 0.0]]
# D: [[0.50677133, -0.43742394], [0.0], [1.2930039, -1.2365541], [0.0, 0.0]]
# A : [[-0.27850878, -0.52411258], [0.0], [0.94569027, 0.83747566], [0.0, 0.0], [0.50677133, -0.43742394], [0.0], [1.2930039, -1.2365541], [0.0, 0.0]]
# after training D --- D and D in A changed
# G: [[-0.27850878, -0.52411258], [0.0], [0.94569027, 0.83747566], [0.0, 0.0]]
# D: [[0.49537802, -0.4082337], [0.0034225769], [1.2876366, -1.2274913], [0.047490694, 0.046951186]]
# A : [[-0.27850878, -0.52411258], [0.0], [0.94569027, 0.83747566], [0.0, 0.0], [0.49537802, -0.4082337], [0.0034225769], [1.2876366, -1.2274913], [0.047490694, 0.046951186]]
# after training A --- D didn't changed!
# G: [[-0.27628738, -0.52191412], [0.0054477928], [0.93868071, 0.84325212], [0.021782838, -0.017950913]]
# D: [[0.49537802, -0.4082337], [0.0034225769], [1.2876366, -1.2274913], [0.047490694, 0.046951186]]
# A : [[-0.27628738, -0.52191412], [0.0054477928], [0.93868071, 0.84325212], [0.021782838, -0.017950913], [0.49537802, -0.4082337], [0.0034225769], [1.2876366, -1.2274913], [0.047490694, 0.046951186]]
# after training D
# G: [[-0.27628738, -0.52191412], [0.0054477928], [0.93868071, 0.84325212], [0.021782838, -0.017950913]]
# D: [[0.45315021, -0.42550534], [-0.069068611], [1.2836961, -1.222793], [0.054722041, 0.11372232]]
# A : [[-0.27628738, -0.52191412], [0.0054477928], [0.93868071, 0.84325212], [0.021782838, -0.017950913], [0.45315021, -0.42550534], [-0.069068611], [1.2836961, -1.222793], [0.054722041, 0.11372232]]
# after training A
# G: [[-0.27531064, -0.52016109], [0.0084079718], [0.93036431, 0.85106117], [0.042959597, -0.037835769]]
# D: [[0.45315021, -0.42550534], [-0.069068611], [1.2836961, -1.222793], [0.054722041, 0.11372232]]
# A : [[-0.27531064, -0.52016109], [0.0084079718], [0.93036431, 0.85106117], [0.042959597, -0.037835769], [0.45315021, -0.42550534], [-0.069068611], [1.2836961, -1.222793], [0.054722041, 0.11372232]]