Skip to content
Snippets Groups Projects
Commit 5b5f60ae authored by Antti Mäkelä's avatar Antti Mäkelä
Browse files

update attention model

parent 1d3cdf77
No related branches found
No related tags found
No related merge requests found
...@@ -24,7 +24,10 @@ an otherwise similar model except with a self-attention layer. ...@@ -24,7 +24,10 @@ an otherwise similar model except with a self-attention layer.
[Self-Attention Generative Adversarial Networks (SAGAN)](https://arxiv.org/abs/1805.08318) [Self-Attention Generative Adversarial Networks (SAGAN)](https://arxiv.org/abs/1805.08318)
explains how to use self-attention in a GAN environment. This model adapts the layer to an Adversarial Autoencoder. explains how to use self-attention in a GAN environment. This model adapts the layer to an Adversarial Autoencoder.
![The latent space with self-attention and color](attention.mp4) ![The latent space with self-attention and color](attention_comparison.mp4)
The animation on the left uses attention while the right one has attention disabled. The network has learned to
utilise attention to control skin tone. Note that the network was trained with both monochrome and color images.
## Adversarial Upscaler ## Adversarial Upscaler
......
...@@ -54,7 +54,7 @@ MODEL_DIR="TF_MODEL_CURRENT/{}".format(MODEL_NAME) ...@@ -54,7 +54,7 @@ MODEL_DIR="TF_MODEL_CURRENT/{}".format(MODEL_NAME)
#activation = tf.nn.relu #activation = tf.nn.relu
activation = lambda x: x * tf.nn.sigmoid(x) activation = tf.nn.swish
#https://arxiv.org/pdf/1707.05776.pdf #https://arxiv.org/pdf/1707.05776.pdf
...@@ -150,16 +150,16 @@ def attention(x, name, attention_summary=False, get_attention=False): ...@@ -150,16 +150,16 @@ def attention(x, name, attention_summary=False, get_attention=False):
def conv_encode(x, name="a2l"): def conv_encode(x, name="a2l"):
with tf.variable_scope(name, reuse=tf.AUTO_REUSE): with tf.variable_scope(name, reuse=tf.AUTO_REUSE):
x = tf.layers.conv2d(x, 16, 5, padding="same", activation=activation) x = tf.layers.conv2d(x, 64, 5, padding="same", activation=activation)
x = tf.layers.max_pooling2d(x, 2, 2) x = tf.layers.max_pooling2d(x, 2, 2)
x = tf.layers.conv2d(x, 32, 5, padding="same", activation=activation) x = tf.layers.conv2d(x, 96, 5, padding="same", activation=activation)
x = tf.layers.max_pooling2d(x, 2, 2) x = tf.layers.max_pooling2d(x, 2, 2)
x = tf.layers.conv2d(x, 64, 3, padding="same", activation=activation) x = tf.layers.conv2d(x, 128, 3, padding="same", activation=activation)
x = tf.layers.max_pooling2d(x, 2, 2) x = tf.layers.max_pooling2d(x, 2, 2)
x = tf.layers.conv2d(x, 128, 3, padding="same", activation=activation) x = tf.layers.conv2d(x, 160, 3, padding="same", activation=activation)
x = residual_layer(x, layers=3, name="{}_res".format(name)) x = residual_layer(x, layers=3, name="{}_res".format(name))
print("Conv Encoder conv {}".format(x.shape)) print("Conv Encoder conv {}".format(x.shape))
...@@ -189,18 +189,18 @@ def conv_decode(latent, input_size, name="l2a", get_attention=False): ...@@ -189,18 +189,18 @@ def conv_decode(latent, input_size, name="l2a", get_attention=False):
#256 -> 192 #256 -> 192
#256 -> 128 #256 -> 128
#256 -> 96 #256 -> 96
x = tf.layers.conv2d_transpose(x, 128, 3, strides=2, padding="same", activation=activation) x = tf.layers.conv2d_transpose(x, 256, 3, strides=2, padding="same", activation=activation)
x = residual_layer(x, layers=2, name="{}_res1".format(name)) x = residual_layer(x, layers=3, name="{}_res1".format(name))
x = tf.layers.conv2d_transpose(x, 64, 3, strides=2, padding="same", activation=activation) x = tf.layers.conv2d_transpose(x, 192, 3, strides=2, padding="same", activation=activation)
x = residual_layer(x, layers=2, name="{}_res2".format(name)) x = residual_layer(x, layers=3, name="{}_res2".format(name))
x = tf.layers.conv2d_transpose(x, 32, 3, strides=2, padding="same", activation=activation) x = tf.layers.conv2d_transpose(x, 128, 3, strides=2, padding="same", activation=activation)
x = residual_layer(x, layers=1, name="{}_res3".format(name)) x = residual_layer(x, layers=1, name="{}_res3".format(name))
x, o = attention(x, "attention", True, True) x, o = attention(x, "attention", True, True)
x = residual_layer(x, layers=1, name="{}_res4".format(name)) x = residual_layer(x, layers=1, name="{}_res4".format(name))
x = tf.layers.conv2d_transpose(x, 16, 3, strides=2, padding="same", activation=activation) x = tf.layers.conv2d_transpose(x, 96, 3, strides=2, padding="same", activation=activation)
x = tf.nn.tanh(tf.layers.conv2d(x, INPUT_CHANNELS, 3, padding="same")) * 2 x = tf.nn.tanh(tf.layers.conv2d(x, INPUT_CHANNELS, 3, padding="same")) * 2
if get_attention: if get_attention:
...@@ -387,7 +387,7 @@ def model_fn(features, labels, mode): ...@@ -387,7 +387,7 @@ def model_fn(features, labels, mode):
tf.summary.scalar("{}_loss".format(name), loss) tf.summary.scalar("{}_loss".format(name), loss)
tf.summary.scalar("{}_generator_loss".format(name), gen_loss) tf.summary.scalar("{}_generator_loss".format(name), gen_loss)
tf.summary.scalar("{}_generator_dist_loss".format(name), gen_loss) tf.summary.scalar("{}_generator_dist_loss".format(name), gen_loss)
return loss, gen_loss + dist_loss * 2.0 return loss, gen_loss + dist_loss
loss_c_a = cycle_loss(a, cycle_a, "a") loss_c_a = cycle_loss(a, cycle_a, "a")
...@@ -396,7 +396,7 @@ def model_fn(features, labels, mode): ...@@ -396,7 +396,7 @@ def model_fn(features, labels, mode):
cyc_loss = loss_c_a * 1.5 cyc_loss = loss_c_a
loss_d = loss_d_a + loss_d_a_im loss_d = loss_d_a + loss_d_a_im
loss_a2l = cyc_loss + loss_g_a2l loss_a2l = cyc_loss + loss_g_a2l
...@@ -428,21 +428,21 @@ def model_fn(features, labels, mode): ...@@ -428,21 +428,21 @@ def model_fn(features, labels, mode):
with tf.variable_scope("regularization", reuse=tf.AUTO_REUSE): with tf.variable_scope("regularization", reuse=tf.AUTO_REUSE):
#A2L reg loss #A2L reg loss
g_a2l_l2_losses = [ tf.nn.l2_loss(v) for v in g_a2l_vars if "kernel" in v.name ] g_a2l_l2_losses = [ tf.nn.l2_loss(v) for v in g_a2l_vars if "kernel" in v.name ]
g_a2l_reg_loss = tf.add_n(g_a2l_l2_losses) / len(g_a2l_l2_losses) * 0.00001 g_a2l_reg_loss = tf.add_n(g_a2l_l2_losses) / len(g_a2l_l2_losses) * 0.000001
tf.summary.scalar("a2l_regularization_loss", g_a2l_reg_loss) tf.summary.scalar("a2l_regularization_loss", g_a2l_reg_loss)
#L2A reg loss #L2A reg loss
g_l2a_l2_losses = [ tf.nn.l2_loss(v) for v in g_l2a_vars if "kernel" in v.name ] g_l2a_l2_losses = [ tf.nn.l2_loss(v) for v in g_l2a_vars if "kernel" in v.name ]
g_l2a_reg_loss = tf.add_n(g_l2a_l2_losses) / len(g_l2a_l2_losses) * 0.00001 g_l2a_reg_loss = tf.add_n(g_l2a_l2_losses) / len(g_l2a_l2_losses) * 0.000001
tf.summary.scalar("l2a_regularization_loss", g_l2a_reg_loss) tf.summary.scalar("l2a_regularization_loss", g_l2a_reg_loss)
#Disc reg loss #Disc reg loss
d_l2_losses = [ tf.nn.l2_loss(v) for v in d_vars if "kernel" in v.name ] d_l2_losses = [ tf.nn.l2_loss(v) for v in d_vars if "kernel" in v.name ]
d_reg_loss = tf.add_n(d_l2_losses) / len(d_l2_losses) * 0.001 d_reg_loss = tf.add_n(d_l2_losses) / len(d_l2_losses) * 0.00003
tf.summary.scalar("discriminator_regularization_loss", d_reg_loss) tf.summary.scalar("discriminator_regularization_loss", d_reg_loss)
loss_a2l += g_a2l_reg_loss loss_a2l += g_a2l_reg_loss
...@@ -570,6 +570,7 @@ def main(inputs): ...@@ -570,6 +570,7 @@ def main(inputs):
tot_output = [] tot_output = []
get_attention = False get_attention = False
save_attention = True
np.random.seed(3) np.random.seed(3)
random.seed(4) random.seed(4)
...@@ -587,7 +588,7 @@ def main(inputs): ...@@ -587,7 +588,7 @@ def main(inputs):
extra[0] = math.sin(rnd*2*3.1415/4.0)*0 extra[0] = math.sin(rnd*2*3.1415/4.0)*0
extra[1] = math.cos(-rnd*2*3.1415/4.0)*0 extra[1] = math.cos(-rnd*2*3.1415/4.0)*0
extra[2] = math.cos(rnd*2*3.1415/4.0)*0 extra[2] = math.cos(rnd*2*3.1415/4.0)*0
rnd += 0.1 rnd += 0.01
lat_mod = np.array(center, dtype=np.float32) lat_mod = np.array(center, dtype=np.float32)
...@@ -612,8 +613,8 @@ def main(inputs): ...@@ -612,8 +613,8 @@ def main(inputs):
lat_mod = tf.reshape(prv[2], [1, LATENT_SIZE]) lat_mod = tf.reshape(prv[2], [1, LATENT_SIZE])
if get_attention: if get_attention:
return {"input": prv[0]} return {"input": prv[0]}
else: else: #When latent_add is an input, the input image has no effect
return {"input": prv[0], "lat_mod": lat_mod} return {"input": prv[0], "latent_add": lat_mod}
prd = estimator.predict(predict_input_fn_2) prd = estimator.predict(predict_input_fn_2)
if get_attention: if get_attention:
...@@ -634,21 +635,43 @@ def main(inputs): ...@@ -634,21 +635,43 @@ def main(inputs):
if maximum > 0: if maximum > 0:
lp = lp / maximum lp = lp / maximum
im = Image.fromarray(np.uint8(lp*255), "L") im = Image.fromarray(np.uint8(lp*255), "L")
im.save("att/attention_{}.png".format(str(x).zfill(4))) im.save("att/attention_{}.png".format(str(x).zfill(4)))
else: else:
for i in range(0, 40): for i in range(0, 400):
p = next(prd) p = next(prd)
lp = p["output"] lp = p["output"]
lp = np.reshape(lp, [INPUT_SIZE, INPUT_SIZE, INPUT_CHANNELS]) lp = np.reshape(lp, [INPUT_SIZE, INPUT_SIZE, INPUT_CHANNELS])
lp = np.clip(lp, -1, 1) lp = np.clip(lp, -1, 1)
im = Image.fromarray(np.uint8(lp*127.5+127.5), "RGB") im = Image.fromarray(np.uint8(lp*127.5+127.5), "RGB")
im.save("o/output_{}.png".format(str(i).zfill(4))) im.save("o/output_{}.png".format(str(i).zfill(4)))
if save_attention:
att = p["attention"]
att_size = att.shape[-2]
att_channels = att.shape[-1]
lp = np.reshape(att[...], [att_size, att_size, att_channels])
lp = lp - np.amin(lp, keepdims=True, axis=(0,1))
maximum = np.amax(lp, keepdims=True, axis=(0,1))
lp = lp / maximum
lp = np.nan_to_num(lp)
lp = np.mean(lp, axis=-1)
lp = lp - np.amin(lp)
lp = lp / np.amax(lp)
lp = np.nan_to_num(lp)
lp = np.uint8(lp * 254)
im = Image.fromarray(lp, "L")
im.save("att/attention_{}.png".format(str(i).zfill(4)))
im2 = Image.fromarray(np.uint8(cur*127.5+127.5), "RGB") # For saving the input image
im2.save("o/orig_{}.png".format(str(i).zfill(4))) #im2 = Image.fromarray(np.uint8(cur*127.5+127.5), "RGB")
#im2.save("o/orig_{}.png".format(str(i).zfill(4)))
else: else:
if len(inputs) > 0: if len(inputs) > 0:
......
File added
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment