-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmask.py
138 lines (111 loc) · 4.5 KB
/
mask.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
import sys
import tensorflow as tf
from PIL import Image, ImageDraw, ImageFont
from transformers import AutoTokenizer, TFBertForMaskedLM
# Hugging Face checkpoint used for masked-token prediction.
MODEL = "bert-base-uncased"

# How many of the highest-scoring predictions to print for the mask.
K = 3

# --- Attention-diagram layout ---
# NOTE(review): this relative path is resolved against the current working
# directory, and the font is loaded as soon as the module is imported, so the
# script only starts when launched from the directory containing
# `redesigned-dollop/` — confirm that is the intended launch location.
FONT = ImageFont.truetype(
    "redesigned-dollop/assets/fonts/OpenSans-Regular.ttf",
    28,
)
# Side length, in pixels, of one attention cell in the grid.
GRID_SIZE = 40
# Size, in pixels, of the top/left margin that holds the token labels.
PIXELS_PER_WORD = 200
def main():
    """
    Prompt for a sentence containing the tokenizer's mask token, print the
    top-K predicted completions for that mask, and save one attention
    diagram per layer/head.

    Exits with an error message if the input contains no mask token.
    """
    text = input("Text: ")

    # Tokenize input
    tokenizer = AutoTokenizer.from_pretrained(MODEL)
    inputs = tokenizer(text, return_tensors="tf")
    mask_token_index = get_mask_token_index(tokenizer.mask_token_id, inputs)
    if mask_token_index is None:
        sys.exit(f"Input must include mask token {tokenizer.mask_token}.")

    # Use model to process input
    model = TFBertForMaskedLM.from_pretrained(MODEL)
    result = model(**inputs, output_attentions=True)

    # Generate predictions for the (first) masked position only
    mask_token_logits = result.logits[0, mask_token_index]
    top_tokens = tf.math.top_k(mask_token_logits, K).indices.numpy()
    for token in top_tokens:
        # get_mask_token_index returns the FIRST mask position, so the
        # prediction belongs to the first mask only; replacing every
        # occurrence would wrongly fill any later masks with the same word.
        prediction = tokenizer.decode([token])
        print(text.replace(tokenizer.mask_token, prediction, 1))

    # Visualize attentions
    visualize_attentions(inputs.tokens(), result.attentions)
def get_mask_token_index(mask_token_id, inputs):
    """
    Locate the mask token in the first sequence of `inputs`.

    Returns the position of the first token whose id equals `mask_token_id`,
    or `None` when no such token appears.
    """
    # Only batch entry 0 is inspected.
    first_sequence = inputs["input_ids"].numpy()[0]
    positions = (
        position
        for position, candidate in enumerate(first_sequence)
        if candidate == mask_token_id
    )
    return next(positions, None)
def get_color_for_attention_score(attention_score):
    """
    Return a tuple of three integers representing a shade of gray for the
    given `attention_score`. Each value is in the range [0, 255].

    The score is scaled linearly so that 0 maps to black (0, 0, 0) and 1 maps
    to white (255, 255, 255), rounding to the nearest shade. Scores outside
    [0, 1] (e.g. from floating-point noise) are clamped to black/white so the
    result is always a valid RGB triple.
    """
    # float() first so NumPy scalars (e.g. np.float32 from an attention
    # array) go through the builtin round() and yield a plain Python int.
    intensity = round(float(attention_score) * 255)
    # Enforce the documented [0, 255] range.
    intensity = max(0, min(255, intensity))
    return (intensity, intensity, intensity)
def visualize_attentions(tokens, attentions):
    """
    Render every attention head of every layer as its own diagram.

    One diagram is produced per (layer, head) pair, each labelled with the
    sentence's `tokens`. Layer and head numbers passed on for the filenames
    are 1-based.
    """
    # The head count is read from the first layer and reused for all layers.
    # NOTE(review): indexing assumes each layer is shaped like
    # (batch, heads, seq, seq) — TODO confirm against the model output docs.
    head_count = len(attentions[0][0])
    for layer_number, layer_attention in enumerate(attentions, start=1):
        # Batch entry 0 only; main() tokenizes a single sentence.
        first_batch = layer_attention[0]
        for head_number in range(1, head_count + 1):
            generate_diagram(
                layer_number,
                head_number,
                tokens,
                first_batch[head_number - 1].numpy(),
            )
def generate_diagram(layer_number, head_number, tokens, attention_weights):
    """
    Generate a diagram representing the self-attention scores for a single
    attention head. The diagram shows one row and column for each of the
    `tokens`, and cells are shaded based on `attention_weights`, with lighter
    cells corresponding to higher attention scores.

    The diagram is saved (in the current working directory) with a filename
    that includes both the `layer_number` and `head_number`.
    """
    # Square canvas: a PIXELS_PER_WORD margin for labels plus one GRID_SIZE
    # cell per token in each direction.
    image_size = GRID_SIZE * len(tokens) + PIXELS_PER_WORD
    img = Image.new("RGBA", (image_size, image_size), "black")
    draw = ImageDraw.Draw(img)

    _draw_column_labels(img, tokens, image_size)
    _draw_row_labels(draw, tokens)
    # Cells are drawn last so they cover anything that spilled into the grid.
    _draw_attention_cells(draw, tokens, attention_weights)

    # Save image
    img.save(f"Attention_Layer{layer_number}_Head{head_number}.png")


def _draw_column_labels(img, tokens, image_size):
    """Paste the tokens as vertical labels in the top margin of `img`."""
    # Every label is written horizontally on ONE transparent layer, which is
    # then rotated 90 degrees counter-clockwise so each label lands above its
    # column. Sharing a single layer avoids allocating and rotating a
    # full-size image per token.
    labels = Image.new("RGBA", (image_size, image_size), (0, 0, 0, 0))
    labels_draw = ImageDraw.Draw(labels)
    for i, token in enumerate(tokens):
        labels_draw.text(
            (image_size - PIXELS_PER_WORD, PIXELS_PER_WORD + i * GRID_SIZE),
            token,
            fill="white",
            font=FONT,
        )
    labels = labels.rotate(90)
    # Using the layer itself as the mask copies only the glyph pixels.
    img.paste(labels, mask=labels)


def _draw_row_labels(draw, tokens):
    """Draw the tokens right-aligned in the left margin, one per grid row."""
    for i, token in enumerate(tokens):
        # Measured from origin (0, 0), the bbox's right edge is the width.
        _, _, width, _ = draw.textbbox((0, 0), token, font=FONT)
        draw.text(
            (PIXELS_PER_WORD - width, PIXELS_PER_WORD + i * GRID_SIZE),
            token,
            fill="white",
            font=FONT,
        )


def _draw_attention_cells(draw, tokens, attention_weights):
    """Shade one GRID_SIZE square per (row token, column token) pair."""
    # Cell (row i, column j) is shaded by attention_weights[i][j]; presumably
    # the attention from token i to token j — TODO confirm with model docs.
    for i in range(len(tokens)):
        y = PIXELS_PER_WORD + i * GRID_SIZE
        for j in range(len(tokens)):
            x = PIXELS_PER_WORD + j * GRID_SIZE
            color = get_color_for_attention_score(attention_weights[i][j])
            draw.rectangle((x, y, x + GRID_SIZE, y + GRID_SIZE), fill=color)
if __name__ == "__main__":
    # Start the interactive prompt only when run as a script, not on import.
    main()