Skip to content
GitLab
Explore
Sign in
Primary navigation
Search or go to…
Project
G
generative-adversarial-model-experiments
Manage
Activity
Members
Labels
Plan
Issues
Issue boards
Milestones
Iterations
Wiki
Requirements
Code
Merge requests
Repository
Branches
Commits
Tags
Repository graph
Compare revisions
Snippets
Locked files
Build
Pipelines
Jobs
Pipeline schedules
Test cases
Artifacts
Deploy
Releases
Container Registry
Model registry
Operate
Environments
Monitor
Incidents
Analyze
Value stream analytics
Contributor analytics
CI/CD analytics
Repository analytics
Code review analytics
Issue analytics
Insights
Model experiments
Help
Help
Support
GitLab documentation
Compare GitLab plans
Community forum
Contribute to GitLab
Provide feedback
Keyboard shortcuts
?
Snippets
Groups
Projects
Show more breadcrumbs
data-analysis-and-ai
generative-adversarial-model-experiments
Commits
5b5f60ae
Commit
5b5f60ae
authored
6 years ago
by
Antti Mäkelä
Browse files
Options
Downloads
Patches
Plain Diff
update attention model
parent
1d3cdf77
No related branches found
No related tags found
No related merge requests found
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
README.md
+4
-1
4 additions, 1 deletion
README.md
aa_attention.py
+45
-22
45 additions, 22 deletions
aa_attention.py
attention_comparison.mp4
+0
-0
0 additions, 0 deletions
attention_comparison.mp4
with
49 additions
and
23 deletions
README.md
+
4
−
1
View file @
5b5f60ae
...
@@ -24,7 +24,10 @@ an otherwise similar model except with a self-attention layer.
...
@@ -24,7 +24,10 @@ an otherwise similar model except with a self-attention layer.


explains how to use self-attention in a GAN environment. This model adapts the layer to Adversarial Autoencoder.
explains how to use self-attention in a GAN environment. This model adapts the layer to Adversarial Autoencoder.


The animation on the left uses attention while the right one has attention disabled. The network has learned to
utilise attention to control skin tone. Note that the network was trained with both monochrome and color images.
## Adversarial Upscaler
## Adversarial Upscaler
...
...
This diff is collapsed.
Click to expand it.
aa_attention.py
+
45
−
22
View file @
5b5f60ae
...
@@ -54,7 +54,7 @@ MODEL_DIR="TF_MODEL_CURRENT/{}".format(MODEL_NAME)
...
@@ -54,7 +54,7 @@ MODEL_DIR="TF_MODEL_CURRENT/{}".format(MODEL_NAME)
#activation = tf.nn.relu
#activation = tf.nn.relu
activation
=
lambda
x
:
x
*
tf
.
nn
.
sigmoid
(
x
)
activation
=
tf
.
nn
.
swish
#https://arxiv.org/pdf/1707.05776.pdf
#https://arxiv.org/pdf/1707.05776.pdf
...
@@ -150,16 +150,16 @@ def attention(x, name, attention_summary=False, get_attention=False):
...
@@ -150,16 +150,16 @@ def attention(x, name, attention_summary=False, get_attention=False):
def
conv_encode
(
x
,
name
=
"
a2l
"
):
def
conv_encode
(
x
,
name
=
"
a2l
"
):
with
tf
.
variable_scope
(
name
,
reuse
=
tf
.
AUTO_REUSE
):
with
tf
.
variable_scope
(
name
,
reuse
=
tf
.
AUTO_REUSE
):
x
=
tf
.
layers
.
conv2d
(
x
,
1
6
,
5
,
padding
=
"
same
"
,
activation
=
activation
)
x
=
tf
.
layers
.
conv2d
(
x
,
6
4
,
5
,
padding
=
"
same
"
,
activation
=
activation
)
x
=
tf
.
layers
.
max_pooling2d
(
x
,
2
,
2
)
x
=
tf
.
layers
.
max_pooling2d
(
x
,
2
,
2
)
x
=
tf
.
layers
.
conv2d
(
x
,
32
,
5
,
padding
=
"
same
"
,
activation
=
activation
)
x
=
tf
.
layers
.
conv2d
(
x
,
96
,
5
,
padding
=
"
same
"
,
activation
=
activation
)
x
=
tf
.
layers
.
max_pooling2d
(
x
,
2
,
2
)
x
=
tf
.
layers
.
max_pooling2d
(
x
,
2
,
2
)
x
=
tf
.
layers
.
conv2d
(
x
,
64
,
3
,
padding
=
"
same
"
,
activation
=
activation
)
x
=
tf
.
layers
.
conv2d
(
x
,
128
,
3
,
padding
=
"
same
"
,
activation
=
activation
)
x
=
tf
.
layers
.
max_pooling2d
(
x
,
2
,
2
)
x
=
tf
.
layers
.
max_pooling2d
(
x
,
2
,
2
)
x
=
tf
.
layers
.
conv2d
(
x
,
1
28
,
3
,
padding
=
"
same
"
,
activation
=
activation
)
x
=
tf
.
layers
.
conv2d
(
x
,
1
60
,
3
,
padding
=
"
same
"
,
activation
=
activation
)
x
=
residual_layer
(
x
,
layers
=
3
,
name
=
"
{}_res
"
.
format
(
name
))
x
=
residual_layer
(
x
,
layers
=
3
,
name
=
"
{}_res
"
.
format
(
name
))
print
(
"
Conv Encoder conv {}
"
.
format
(
x
.
shape
))
print
(
"
Conv Encoder conv {}
"
.
format
(
x
.
shape
))
...
@@ -189,18 +189,18 @@ def conv_decode(latent, input_size, name="l2a", get_attention=False):
...
@@ -189,18 +189,18 @@ def conv_decode(latent, input_size, name="l2a", get_attention=False):
#256 -> 192
#256 -> 192
#256 -> 128
#256 -> 128
#256 -> 96
#256 -> 96
x
=
tf
.
layers
.
conv2d_transpose
(
x
,
128
,
3
,
strides
=
2
,
padding
=
"
same
"
,
activation
=
activation
)
x
=
tf
.
layers
.
conv2d_transpose
(
x
,
256
,
3
,
strides
=
2
,
padding
=
"
same
"
,
activation
=
activation
)
x
=
residual_layer
(
x
,
layers
=
2
,
name
=
"
{}_res1
"
.
format
(
name
))
x
=
residual_layer
(
x
,
layers
=
3
,
name
=
"
{}_res1
"
.
format
(
name
))
x
=
tf
.
layers
.
conv2d_transpose
(
x
,
64
,
3
,
strides
=
2
,
padding
=
"
same
"
,
activation
=
activation
)
x
=
tf
.
layers
.
conv2d_transpose
(
x
,
192
,
3
,
strides
=
2
,
padding
=
"
same
"
,
activation
=
activation
)
x
=
residual_layer
(
x
,
layers
=
2
,
name
=
"
{}_res2
"
.
format
(
name
))
x
=
residual_layer
(
x
,
layers
=
3
,
name
=
"
{}_res2
"
.
format
(
name
))
x
=
tf
.
layers
.
conv2d_transpose
(
x
,
32
,
3
,
strides
=
2
,
padding
=
"
same
"
,
activation
=
activation
)
x
=
tf
.
layers
.
conv2d_transpose
(
x
,
128
,
3
,
strides
=
2
,
padding
=
"
same
"
,
activation
=
activation
)
x
=
residual_layer
(
x
,
layers
=
1
,
name
=
"
{}_res3
"
.
format
(
name
))
x
=
residual_layer
(
x
,
layers
=
1
,
name
=
"
{}_res3
"
.
format
(
name
))
x
,
o
=
attention
(
x
,
"
attention
"
,
True
,
True
)
x
,
o
=
attention
(
x
,
"
attention
"
,
True
,
True
)
x
=
residual_layer
(
x
,
layers
=
1
,
name
=
"
{}_res4
"
.
format
(
name
))
x
=
residual_layer
(
x
,
layers
=
1
,
name
=
"
{}_res4
"
.
format
(
name
))
x
=
tf
.
layers
.
conv2d_transpose
(
x
,
1
6
,
3
,
strides
=
2
,
padding
=
"
same
"
,
activation
=
activation
)
x
=
tf
.
layers
.
conv2d_transpose
(
x
,
9
6
,
3
,
strides
=
2
,
padding
=
"
same
"
,
activation
=
activation
)
x
=
tf
.
nn
.
tanh
(
tf
.
layers
.
conv2d
(
x
,
INPUT_CHANNELS
,
3
,
padding
=
"
same
"
))
*
2
x
=
tf
.
nn
.
tanh
(
tf
.
layers
.
conv2d
(
x
,
INPUT_CHANNELS
,
3
,
padding
=
"
same
"
))
*
2
if
get_attention
:
if
get_attention
:
...
@@ -387,7 +387,7 @@ def model_fn(features, labels, mode):
...
@@ -387,7 +387,7 @@ def model_fn(features, labels, mode):
tf
.
summary
.
scalar
(
"
{}_loss
"
.
format
(
name
),
loss
)
tf
.
summary
.
scalar
(
"
{}_loss
"
.
format
(
name
),
loss
)
tf
.
summary
.
scalar
(
"
{}_generator_loss
"
.
format
(
name
),
gen_loss
)
tf
.
summary
.
scalar
(
"
{}_generator_loss
"
.
format
(
name
),
gen_loss
)
tf
.
summary
.
scalar
(
"
{}_generator_dist_loss
"
.
format
(
name
),
gen_loss
)
tf
.
summary
.
scalar
(
"
{}_generator_dist_loss
"
.
format
(
name
),
gen_loss
)
return
loss
,
gen_loss
+
dist_loss
*
2.0
return
loss
,
gen_loss
+
dist_loss
loss_c_a
=
cycle_loss
(
a
,
cycle_a
,
"
a
"
)
loss_c_a
=
cycle_loss
(
a
,
cycle_a
,
"
a
"
)
...
@@ -396,7 +396,7 @@ def model_fn(features, labels, mode):
...
@@ -396,7 +396,7 @@ def model_fn(features, labels, mode):
cyc_loss
=
loss_c_a
*
1.5
cyc_loss
=
loss_c_a
loss_d
=
loss_d_a
+
loss_d_a_im
loss_d
=
loss_d_a
+
loss_d_a_im
loss_a2l
=
cyc_loss
+
loss_g_a2l
loss_a2l
=
cyc_loss
+
loss_g_a2l
...
@@ -428,21 +428,21 @@ def model_fn(features, labels, mode):
...
@@ -428,21 +428,21 @@ def model_fn(features, labels, mode):
with
tf
.
variable_scope
(
"
regularization
"
,
reuse
=
tf
.
AUTO_REUSE
):
with
tf
.
variable_scope
(
"
regularization
"
,
reuse
=
tf
.
AUTO_REUSE
):
#A2L reg loss
#A2L reg loss
g_a2l_l2_losses
=
[
tf
.
nn
.
l2_loss
(
v
)
for
v
in
g_a2l_vars
if
"
kernel
"
in
v
.
name
]
g_a2l_l2_losses
=
[
tf
.
nn
.
l2_loss
(
v
)
for
v
in
g_a2l_vars
if
"
kernel
"
in
v
.
name
]
g_a2l_reg_loss
=
tf
.
add_n
(
g_a2l_l2_losses
)
/
len
(
g_a2l_l2_losses
)
*
0.00001
g_a2l_reg_loss
=
tf
.
add_n
(
g_a2l_l2_losses
)
/
len
(
g_a2l_l2_losses
)
*
0.0000
0
1
tf
.
summary
.
scalar
(
"
a2l_regularization_loss
"
,
g_a2l_reg_loss
)
tf
.
summary
.
scalar
(
"
a2l_regularization_loss
"
,
g_a2l_reg_loss
)
#L2A reg loss
#L2A reg loss
g_l2a_l2_losses
=
[
tf
.
nn
.
l2_loss
(
v
)
for
v
in
g_l2a_vars
if
"
kernel
"
in
v
.
name
]
g_l2a_l2_losses
=
[
tf
.
nn
.
l2_loss
(
v
)
for
v
in
g_l2a_vars
if
"
kernel
"
in
v
.
name
]
g_l2a_reg_loss
=
tf
.
add_n
(
g_l2a_l2_losses
)
/
len
(
g_l2a_l2_losses
)
*
0.00001
g_l2a_reg_loss
=
tf
.
add_n
(
g_l2a_l2_losses
)
/
len
(
g_l2a_l2_losses
)
*
0.0000
0
1
tf
.
summary
.
scalar
(
"
l2a_regularization_loss
"
,
g_l2a_reg_loss
)
tf
.
summary
.
scalar
(
"
l2a_regularization_loss
"
,
g_l2a_reg_loss
)
#Disc reg loss
#Disc reg loss
d_l2_losses
=
[
tf
.
nn
.
l2_loss
(
v
)
for
v
in
d_vars
if
"
kernel
"
in
v
.
name
]
d_l2_losses
=
[
tf
.
nn
.
l2_loss
(
v
)
for
v
in
d_vars
if
"
kernel
"
in
v
.
name
]
d_reg_loss
=
tf
.
add_n
(
d_l2_losses
)
/
len
(
d_l2_losses
)
*
0.00
1
d_reg_loss
=
tf
.
add_n
(
d_l2_losses
)
/
len
(
d_l2_losses
)
*
0.00
003
tf
.
summary
.
scalar
(
"
discriminator_regularization_loss
"
,
d_reg_loss
)
tf
.
summary
.
scalar
(
"
discriminator_regularization_loss
"
,
d_reg_loss
)
loss_a2l
+=
g_a2l_reg_loss
loss_a2l
+=
g_a2l_reg_loss
...
@@ -570,6 +570,7 @@ def main(inputs):
...
@@ -570,6 +570,7 @@ def main(inputs):
tot_output
=
[]
tot_output
=
[]
get_attention
=
False
get_attention
=
False
save_attention
=
True
np
.
random
.
seed
(
3
)
np
.
random
.
seed
(
3
)
random
.
seed
(
4
)
random
.
seed
(
4
)
...
@@ -587,7 +588,7 @@ def main(inputs):
...
@@ -587,7 +588,7 @@ def main(inputs):
extra
[
0
]
=
math
.
sin
(
rnd
*
2
*
3.1415
/
4.0
)
*
0
extra
[
0
]
=
math
.
sin
(
rnd
*
2
*
3.1415
/
4.0
)
*
0
extra
[
1
]
=
math
.
cos
(
-
rnd
*
2
*
3.1415
/
4.0
)
*
0
extra
[
1
]
=
math
.
cos
(
-
rnd
*
2
*
3.1415
/
4.0
)
*
0
extra
[
2
]
=
math
.
cos
(
rnd
*
2
*
3.1415
/
4.0
)
*
0
extra
[
2
]
=
math
.
cos
(
rnd
*
2
*
3.1415
/
4.0
)
*
0
rnd
+=
0.1
rnd
+=
0.
0
1
lat_mod
=
np
.
array
(
center
,
dtype
=
np
.
float32
)
lat_mod
=
np
.
array
(
center
,
dtype
=
np
.
float32
)
...
@@ -612,8 +613,8 @@ def main(inputs):
...
@@ -612,8 +613,8 @@ def main(inputs):
lat_mod
=
tf
.
reshape
(
prv
[
2
],
[
1
,
LATENT_SIZE
])
lat_mod
=
tf
.
reshape
(
prv
[
2
],
[
1
,
LATENT_SIZE
])
if
get_attention
:
if
get_attention
:
return
{
"
input
"
:
prv
[
0
]}
return
{
"
input
"
:
prv
[
0
]}
else
:
else
:
#When latent_add is an input, the input image has no effect
return
{
"
input
"
:
prv
[
0
],
"
lat
_mo
d
"
:
lat_mod
}
return
{
"
input
"
:
prv
[
0
],
"
lat
ent_ad
d
"
:
lat_mod
}
prd
=
estimator
.
predict
(
predict_input_fn_2
)
prd
=
estimator
.
predict
(
predict_input_fn_2
)
if
get_attention
:
if
get_attention
:
...
@@ -634,21 +635,43 @@ def main(inputs):
...
@@ -634,21 +635,43 @@ def main(inputs):
if
maximum
>
0
:
if
maximum
>
0
:
lp
=
lp
/
maximum
lp
=
lp
/
maximum
im
=
Image
.
fromarray
(
np
.
uint8
(
lp
*
255
),
"
L
"
)
im
=
Image
.
fromarray
(
np
.
uint8
(
lp
*
255
),
"
L
"
)
im
.
save
(
"
att/attention_{}.png
"
.
format
(
str
(
x
).
zfill
(
4
)))
im
.
save
(
"
att/attention_{}.png
"
.
format
(
str
(
x
).
zfill
(
4
)))
else
:
else
:
for
i
in
range
(
0
,
40
):
for
i
in
range
(
0
,
40
0
):
p
=
next
(
prd
)
p
=
next
(
prd
)
lp
=
p
[
"
output
"
]
lp
=
p
[
"
output
"
]
lp
=
np
.
reshape
(
lp
,
[
INPUT_SIZE
,
INPUT_SIZE
,
INPUT_CHANNELS
])
lp
=
np
.
reshape
(
lp
,
[
INPUT_SIZE
,
INPUT_SIZE
,
INPUT_CHANNELS
])
lp
=
np
.
clip
(
lp
,
-
1
,
1
)
lp
=
np
.
clip
(
lp
,
-
1
,
1
)
im
=
Image
.
fromarray
(
np
.
uint8
(
lp
*
127.5
+
127.5
),
"
RGB
"
)
im
=
Image
.
fromarray
(
np
.
uint8
(
lp
*
127.5
+
127.5
),
"
RGB
"
)
im
.
save
(
"
o/output_{}.png
"
.
format
(
str
(
i
).
zfill
(
4
)))
im
.
save
(
"
o/output_{}.png
"
.
format
(
str
(
i
).
zfill
(
4
)))
if
save_attention
:
att
=
p
[
"
attention
"
]
att_size
=
att
.
shape
[
-
2
]
att_channels
=
att
.
shape
[
-
1
]
lp
=
np
.
reshape
(
att
[...],
[
att_size
,
att_size
,
att_channels
])
lp
=
lp
-
np
.
amin
(
lp
,
keepdims
=
True
,
axis
=
(
0
,
1
))
maximum
=
np
.
amax
(
lp
,
keepdims
=
True
,
axis
=
(
0
,
1
))
lp
=
lp
/
maximum
lp
=
np
.
nan_to_num
(
lp
)
lp
=
np
.
mean
(
lp
,
axis
=-
1
)
lp
=
lp
-
np
.
amin
(
lp
)
lp
=
lp
/
np
.
amax
(
lp
)
lp
=
np
.
nan_to_num
(
lp
)
lp
=
np
.
uint8
(
lp
*
254
)
im
=
Image
.
fromarray
(
lp
,
"
L
"
)
im
.
save
(
"
att/attention_{}.png
"
.
format
(
str
(
i
).
zfill
(
4
)))
im2
=
Image
.
fromarray
(
np
.
uint8
(
cur
*
127.5
+
127.5
),
"
RGB
"
)
# For saving the input image
im2
.
save
(
"
o/orig_{}.png
"
.
format
(
str
(
i
).
zfill
(
4
)))
#im2 = Image.fromarray(np.uint8(cur*127.5+127.5), "RGB")
#im2.save("o/orig_{}.png".format(str(i).zfill(4)))
else
:
else
:
if
len
(
inputs
)
>
0
:
if
len
(
inputs
)
>
0
:
...
...
This diff is collapsed.
Click to expand it.
attention_comparison.mp4
0 → 100644
+
0
−
0
View file @
5b5f60ae
File added
This diff is collapsed.
Click to expand it.
Preview
0%
Loading
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Save comment
Cancel
Please
register
or
sign in
to comment