Sentence Reconstruction¶
Name and ID: Angelo Galavotti 0001103433
This notebook contains the submission for the Deep Learning of 14/06/2023.
Description of the task¶
Take as input a sequence of words corresponding to a random permutation of a given English sentence, and reconstruct the original sentence (for instance, from the shuffled input "mat the on sat cat the" the expected output is "the cat sat on the mat").
The output can be produced either in a single shot, or through an iterative (autoregressive) loop that generates one token at a time.
CONSTRAINTS:
- No pretrained model can be used.
- The neural network models should have less than 20M parameters.
Solution approach¶
To compute a valid solution, I decided to adopt a model based on Transformers and multi-head attention.
In this notebook, I describe the most important steps of the approach. At the end of the notebook, I also briefly describe my previous attempts.
Downloading the dataset¶
!pip install datasets
!pip3 install apache-beam
from random import Random
# Instantiate the Random instance with random seed = 42 to ensure reproducibility
randomizer = Random(42)
!pip install gdown
from keras.preprocessing.text import Tokenizer
from keras.utils import to_categorical, pad_sequences
import numpy as np
import pickle
import gdown
import random
from datasets import load_dataset
dataset = load_dataset("wikipedia", "20220301.simple")
data = dataset['train'][:20000]['text']
Tokenization¶
# Run this cell only the first time, to create and save the tokenizer and the data
dump = True
tokenizer = Tokenizer(split=' ', filters='!"#$%&()*+,-./:;=?@[\\]^_`{|}~\t\n', num_words=10000, oov_token='<unk>')
corpus = []
# Split each piece of text into sentences
for elem in data:
corpus += elem.lower().replace("\n", "").split(".")[:]
print("corpus dim: ",len(corpus))
#add a start and an end token
corpus = ['<start> '+s+' <end>' for s in corpus]
# Tokenization
tokenizer.fit_on_texts(corpus)
#print(tokenizer.word_index['<unk>'])
if dump:
with open('tokenizer.pickle', 'wb') as handle:
pickle.dump(tokenizer, handle, protocol=pickle.HIGHEST_PROTOCOL)
original_data = [sen for sen in tokenizer.texts_to_sequences(corpus) if (len(sen) <= 32 and len(sen)>4 and not(1 in sen))]
if dump:
with open('original.pickle', 'wb') as handle:
pickle.dump(original_data, handle, protocol=pickle.HIGHEST_PROTOCOL)
print ("filtered sentences: ",len(original_data))
sos = tokenizer.word_index['<start>']
eos = tokenizer.word_index['<end>']
#print(eos)
#print(tokenizer.index_word[sos])
tokenizer.word_index['<pad>'] = 0
tokenizer.index_word[0] = '<pad>'
# dimension of the vocabulary of tokens
vocab_dimension = len(tokenizer.word_index) + 1
corpus dim: 510023 filtered sentences: 137301
shuffled_data = [randomizer.sample(s[1:-1], len(s)-2) for s in original_data]  # shuffle with the seeded Random instance for reproducibility
shuffled_data = [[sos]+s+[eos] for s in shuffled_data] # shuffled_data is an input of the model
target_data = [s[1:] for s in original_data] # target_data is the same as original data but offset by one timestep
from sklearn.model_selection import train_test_split
x_train, x_test, c_train, c_test, y_train, y_test = train_test_split(original_data, shuffled_data, target_data, test_size = 0.3, random_state = 42)
Score function¶
from difflib import SequenceMatcher
def score(s,p):
match = SequenceMatcher(None, s, p).find_longest_match()
#print(match.size)
return (match.size/max(len(s),len(p)))
def clean_sentence(x):
x = x.replace('<start>', '').replace('<end>', '').replace('<pad>', '').strip()
return x
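As a quick sanity check (not part of the original run), the score is the length of the longest common contiguous block divided by the length of the longer sequence:
# Illustrative (hypothetical) check of the score metric on two toy strings
print(score("the cat sat on the mat", "the cat sat on a mat"))  # longest common block: "the cat sat on "
print(score("abc", "abc"))                                      # identical strings score 1.0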
i = np.random.randint(len(original_data))
print("original sentence: ",original_data[i])
print("shuffled sentecen: ",shuffled_data[i])
original sentence: [2, 4, 780, 14, 5, 60, 829, 6, 1043, 20, 188, 1520, 21, 191, 31, 9, 75, 172, 1520, 18, 56, 23, 2053, 1777, 3] shuffled sentecen: [2, 9, 60, 31, 18, 23, 2053, 191, 780, 172, 188, 14, 75, 56, 6, 1777, 20, 1520, 1520, 21, 4, 5, 829, 1043, 3]
Dataset padding/formatting¶
max_sequence_len = max([len(x) for x in original_data])
x_train = pad_sequences(x_train, maxlen=max_sequence_len, padding='post')
x_test = pad_sequences(x_test, maxlen=max_sequence_len, padding='post')
c_train = pad_sequences(c_train, maxlen=max_sequence_len, padding='post')
c_test = pad_sequences(c_test, maxlen=max_sequence_len, padding='post')
y_train = pad_sequences(y_train, maxlen=max_sequence_len, padding='post')
y_test = pad_sequences(y_test, maxlen=max_sequence_len, padding='post')
print("x_train size:", len(x_train))
assert(len(x_train)==len(c_train)==len(y_train))
print(len(x_train))
print(len(c_train))
print(len(y_train))
x_train size: 96110 96110 96110 96110
i = np.random.randint(len(x_train))
print("original sentence: ",tokenizer.sequences_to_texts([x_train[i]])[0])
print("shuffled sentence: ",tokenizer.sequences_to_texts([c_train[i]])[0])
original sentence: <start> in this way people can read many articles easily but it is illegal <end> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> shuffled sentence: <start> many way people articles it read in easily but illegal this is can <end> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad>
The model¶
After some attempts using RNNs and LSTMs, I decided to opt for a Transformer-based model, for two main reasons:
- Transformers capture hidden dependencies in the data.
- They make no assumptions about the spatial relationships across the data.
The latter property was essential for the performance of this model: the model should behave the same regardless of the ordering of the inputs, a property that LSTMs do not guarantee.
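To make this concrete, here is a minimal sketch (not part of the original notebook) showing that self-attention without positional encoding is permutation-equivariant: shuffling the input tokens simply shuffles the corresponding output rows.
import numpy as np
import tensorflow as tf
mha = tf.keras.layers.MultiHeadAttention(num_heads=2, key_dim=8)
x = tf.random.normal((1, 5, 8))                       # (batch, tokens, features)
perm = np.random.permutation(5)                       # a random reordering of the 5 tokens
x_perm = tf.gather(x, perm, axis=1)
y = mha(query=x, value=x, key=x)                      # self-attention on the original order
y_perm = mha(query=x_perm, value=x_perm, key=x_perm)  # self-attention on the shuffled order
# The outputs match up to the same permutation of the rows
print(np.allclose(tf.gather(y, perm, axis=1).numpy(), y_perm.numpy(), atol=1e-5))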
Building the layers¶
The model is comprised of this type of layers:
- Base attention layer
- Cross attention layer
- Global and Causal self attention layer
- Feed Forward layer
Let's look over their code and their inner functioning.
Base Attention Layer¶
The base attention layer is comprised of a multi-head attention layer followed by an Add & Norm step.
In particular, each attention head can specialize in different aspects or dependencies of the sequence it receives.
import tensorflow as tf
from keras.layers import Embedding
class BaseAttention(tf.keras.layers.Layer):
def __init__(self, **kwargs):
super().__init__()
self.mha = tf.keras.layers.MultiHeadAttention(**kwargs)
self.layernorm = tf.keras.layers.LayerNormalization()
self.add = tf.keras.layers.Add()
Cross-attention layer¶
The cross-attention layer connects the encoder and the decoder: the decoder attends over the context sequence produced by the encoder.
class CrossAttention(BaseAttention):
def call(self, x, context):
attn_output, attn_scores = self.mha(
query=x,
key=context,
value=context,
return_attention_scores=True)
# Cache the attention scores
self.last_attn_scores = attn_scores
x = self.add([x, attn_output])
x = self.layernorm(x)
return x
Global self attention layer¶
This layer is responsible for processing/generating the context sequence, and propagating information along its length.
class GlobalSelfAttention(BaseAttention):
def call(self, x):
attn_output = self.mha(
query=x,
value=x,
key=x)
x = self.add([x, attn_output])
x = self.layernorm(x)
return x
Causal self attention layer¶
This layer does the same thing as the global self attention layer, but for the output sequence.
As a matter of fact, their structure is very similar; the only difference is the causal mask, which prevents each position from attending to later positions.
class CausalSelfAttention(BaseAttention):
def call(self, x):
attn_output = self.mha(
query=x,
value=x,
key=x,
use_causal_mask = True)
x = self.add([x, attn_output])
x = self.layernorm(x)
return x
Feed forward layer¶
This layer is comprised of two dense layers (the first with a ReLU activation) followed by a dropout layer, which helps reduce overfitting; a residual connection and layer normalization are applied on top.
class FeedForward(tf.keras.layers.Layer):
def __init__(self, d_model, dff, dropout_rate=0.1):
super().__init__()
self.seq = tf.keras.Sequential([
tf.keras.layers.Dense(dff, activation='relu'),
tf.keras.layers.Dense(d_model),
tf.keras.layers.Dropout(dropout_rate)
])
self.add = tf.keras.layers.Add()
self.layer_norm = tf.keras.layers.LayerNormalization()
def call(self, x):
x = self.add([x, self.seq(x)])
x = self.layer_norm(x)
return x
Positional Embedding Layer¶
A normal embedding layer converts each input token into a vector, so that it can be fed to the neural network.
A positional embedding adds a positional encoding to these vectors, so that the position of each word in the sequence is taken into account.
def positional_encoding(length, depth):
depth = depth/2
positions = np.arange(length)[:, np.newaxis]
depths = np.arange(depth)[np.newaxis, :]/depth
angle_rates = 1 / (10000**depths)
angle_rads = positions * angle_rates
pos_encoding = np.concatenate(
[np.sin(angle_rads), np.cos(angle_rads)],
axis=-1)
return tf.cast(pos_encoding, dtype=tf.float32)
class PositionalEmbedding(tf.keras.layers.Layer):
def __init__(self, vocab_size, d_model):
super().__init__()
self.d_model = d_model
self.embedding = tf.keras.layers.Embedding(vocab_size, d_model, mask_zero=True)
self.pos_encoding = positional_encoding(length=2048, depth=d_model)
def compute_mask(self, *args, **kwargs):
return self.embedding.compute_mask(*args, **kwargs)
def call(self, x):
length = tf.shape(x)[1]
x = self.embedding(x)
# This factor sets the relative scale of the embedding and positonal_encoding.
x *= tf.math.sqrt(tf.cast(self.d_model, tf.float32))
x = x + self.pos_encoding[tf.newaxis, :length, :]
return x
Encoder layer¶
class EncoderLayer(tf.keras.layers.Layer):
def __init__(self,*, d_model, num_heads, dff, dropout_rate=0.1):
super().__init__()
self.self_attention = GlobalSelfAttention(
num_heads=num_heads,
key_dim=d_model,
dropout=dropout_rate)
self.ffn = FeedForward(d_model, dff)
def call(self, x):
x = self.self_attention(x)
x = self.ffn(x)
return x
Encoder¶
In the encoder, the positional embedding layer is replaced with a plain embedding layer.
This is not without reason: without positional embeddings, the input is seen as a "bag of words", in which the order of the words is not taken into account.
This is exactly what we want: the model should behave in the same way for every possible ordering of the same set of words.
class Encoder(tf.keras.layers.Layer):
def __init__(self, *, num_layers, d_model, num_heads,
dff, vocab_size, dropout_rate=0.1):
super().__init__()
self.d_model = d_model
self.num_layers = num_layers
# No positional embedding, since we need this model to treat input as BoW
self.embedding = Embedding(input_dim=vocab_size, output_dim=d_model)
self.enc_layers = [
EncoderLayer(d_model=d_model,
num_heads=num_heads,
dff=dff,
dropout_rate=dropout_rate)
for _ in range(num_layers)]
self.dropout = tf.keras.layers.Dropout(dropout_rate)
def call(self, x):
x = self.embedding(x)
x = self.dropout(x)
for i in range(self.num_layers):
x = self.enc_layers[i](x)
return x
Decoder¶
The structure of the decoder is very similar to the structure of the encoder, aside from a few differences.
Decoder layer¶
Each decoder layer is made of a causal self attention layer and a feed forward layer.
In addition, it incorporates a cross attention layer, which receives the context produced by the encoder.
class DecoderLayer(tf.keras.layers.Layer):
def __init__(self,
*,
d_model,
num_heads,
dff,
dropout_rate=0.1):
super(DecoderLayer, self).__init__()
self.causal_self_attention = CausalSelfAttention(
num_heads=num_heads,
key_dim=d_model,
dropout=dropout_rate)
self.cross_attention = CrossAttention(
num_heads=num_heads,
key_dim=d_model,
dropout=dropout_rate)
self.ffn = FeedForward(d_model, dff)
def call(self, x, context):
x = self.causal_self_attention(x=x)
x = self.cross_attention(x=x, context=context)
# Cache the last attention scores
self.last_attn_scores = self.cross_attention.last_attn_scores
x = self.ffn(x)
return x
As opposed to the encoder, the decoder uses a positional embedding: during teacher forcing it must capture the positional information of the ordered target sentence.
class Decoder(tf.keras.layers.Layer):
def __init__(self, *, num_layers, d_model, num_heads, dff, vocab_size,
dropout_rate=0.1):
super(Decoder, self).__init__()
self.d_model = d_model
self.num_layers = num_layers
self.pos_embedding = PositionalEmbedding(vocab_size=vocab_size,
d_model=d_model)
self.dropout = tf.keras.layers.Dropout(dropout_rate)
self.dec_layers = [
DecoderLayer(d_model=d_model, num_heads=num_heads,
dff=dff, dropout_rate=dropout_rate)
for _ in range(num_layers)]
self.last_attn_scores = None
def call(self, x, context):
x = self.pos_embedding(x)
x = self.dropout(x)
for i in range(self.num_layers):
x = self.dec_layers[i](x, context)
self.last_attn_scores = self.dec_layers[-1].last_attn_scores
return x
Final transformer¶
Putting everything together, we obtain the transformer.
We also add a final Dense layer, which converts the resulting vector at each location into output token logits.
class Transformer(tf.keras.Model):
def __init__(self, *, num_layers, d_model, num_heads, dff,
input_vocab_size, target_vocab_size, dropout_rate=0.1):
super().__init__()
self.encoder = Encoder(num_layers=num_layers, d_model=d_model,
num_heads=num_heads, dff=dff,
vocab_size=input_vocab_size,
dropout_rate=dropout_rate)
self.decoder = Decoder(num_layers=num_layers, d_model=d_model,
num_heads=num_heads, dff=dff,
vocab_size=target_vocab_size,
dropout_rate=dropout_rate)
self.final_layer = tf.keras.layers.Dense(target_vocab_size)
def call(self, inputs):
# computing and giving context to decoder
context, x = inputs
context = self.encoder(context)
x = self.decoder(x, context)
# Final linear layer output.
logits = self.final_layer(x)
try:
# Drop the keras mask, so it doesn't scale the losses/metrics.
del logits._keras_mask
except AttributeError:
pass
return logits
Instantiating the model¶
The model is instantiated with the following parameters.
Each of them was chosen through trial and error, by training different models with different combinations of parameters.
Two of the most influential were the number of heads and the dropout rate:
- The number of heads influences how well the model captures the underlying dependencies in the sequences.
- The dropout rate influences how prone the model is to overfitting or underfitting.
num_layers = 4
d_model = 128
dff = 512
num_heads = 8
dropout_rate = 0.2
transformer = Transformer(
num_layers=num_layers,
d_model=d_model,
num_heads=num_heads,
dff=dff,
input_vocab_size=10_000,
target_vocab_size=10_000,
dropout_rate=dropout_rate)
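As a quick hedged check (not part of the original run), the model can be built on dummy inputs and its parameter count compared against the 20M constraint; count_params() only works after the model has been called once.
# Hypothetical sanity check of the parameter budget
dummy = tf.ones((1, max_sequence_len), dtype=tf.int64)  # fake batch of token ids
_ = transformer((dummy, dummy))                         # build the model with a forward pass
print(f"Total parameters: {transformer.count_params():,}")
assert transformer.count_params() < 20_000_000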
Training the model¶
The model uses the Adam optimizer. The learning rate schedule follows the paper "Attention Is All You Need", in which Transformers were first introduced.
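For reference, the schedule implemented in the cell below is
$$\text{lrate} = d_{\text{model}}^{-0.5} \cdot \min\!\left(\text{step}^{-0.5},\ \text{step} \cdot \text{warmup\_steps}^{-1.5}\right)$$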
class CustomSchedule(tf.keras.optimizers.schedules.LearningRateSchedule):
def __init__(self, d_model, warmup_steps=4000):
super().__init__()
self.d_model = d_model
self.d_model = tf.cast(self.d_model, tf.float32)
self.warmup_steps = warmup_steps
def __call__(self, step):
step = tf.cast(step, dtype=tf.float32)
arg1 = tf.math.rsqrt(step)
arg2 = step * (self.warmup_steps ** -1.5)
return tf.math.rsqrt(self.d_model) * tf.math.minimum(arg1, arg2)
learning_rate = CustomSchedule(d_model)
optimizer = tf.keras.optimizers.Adam(learning_rate, beta_1=0.9, beta_2=0.98,
epsilon=1e-9)
Loss function and metrics¶
The sparse categorical cross-entropy and accuracy are extended to include a padding mask.
def masked_loss(label, pred):
mask = label != 0
loss_object = tf.keras.losses.SparseCategoricalCrossentropy(
from_logits=True, reduction='none')
loss = loss_object(label, pred)
mask = tf.cast(mask, dtype=loss.dtype)
loss *= mask
loss = tf.reduce_sum(loss)/tf.reduce_sum(mask)
return loss
def masked_accuracy(label, pred):
pred = tf.argmax(pred, axis=2)
label = tf.cast(label, pred.dtype)
match = label == pred
mask = label != 0
match = match & mask
match = tf.cast(match, dtype=tf.float32)
mask = tf.cast(mask, dtype=tf.float32)
return tf.reduce_sum(match)/tf.reduce_sum(mask)
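As a small sanity check (not in the original notebook), padded positions (label 0) are excluded from the metric:
# Hypothetical toy example: only the two non-padded positions contribute
labels = tf.constant([[3, 5, 0, 0]])          # last two positions are padding
logits = tf.random.normal((1, 4, 10))         # (batch, timesteps, vocab) of random scores
print(masked_accuracy(labels, logits))        # averaged over the 2 real tokens only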
Compiling and training the model¶
The model is built and set up for training.
transformer.compile(
loss=masked_loss,
optimizer=optimizer,
metrics=[masked_accuracy]
)
from keras.callbacks import EarlyStopping, ReduceLROnPlateau
early_stopping = EarlyStopping(monitor='val_masked_accuracy', mode='max', verbose=1, patience=5)
epochs = 50
batch_size = 256
transformer.fit(
(c_train, x_train),
y_train,
epochs=epochs,
batch_size=batch_size,
callbacks = [early_stopping],
validation_split = 0.05
)
Epoch 1/50 357/357 [==============================] - 186s 416ms/step - loss: 8.2259 - masked_accuracy: 0.1062 - val_loss: 6.9074 - val_masked_accuracy: 0.1428 Epoch 2/50 357/357 [==============================] - 117s 328ms/step - loss: 6.1884 - masked_accuracy: 0.1925 - val_loss: 5.4053 - val_masked_accuracy: 0.2822 Epoch 3/50 357/357 [==============================] - 115s 321ms/step - loss: 4.9012 - masked_accuracy: 0.3262 - val_loss: 4.2626 - val_masked_accuracy: 0.3935 Epoch 4/50 357/357 [==============================] - 114s 320ms/step - loss: 3.9539 - masked_accuracy: 0.4154 - val_loss: 3.4291 - val_masked_accuracy: 0.4757 Epoch 5/50 357/357 [==============================] - 113s 315ms/step - loss: 3.1970 - masked_accuracy: 0.4888 - val_loss: 2.7783 - val_masked_accuracy: 0.5496 Epoch 6/50 357/357 [==============================] - 113s 316ms/step - loss: 2.5962 - masked_accuracy: 0.5464 - val_loss: 2.2657 - val_masked_accuracy: 0.5936 Epoch 7/50 357/357 [==============================] - 112s 313ms/step - loss: 2.1597 - masked_accuracy: 0.5874 - val_loss: 1.9209 - val_masked_accuracy: 0.6282 Epoch 8/50 357/357 [==============================] - 112s 313ms/step - loss: 1.8573 - masked_accuracy: 0.6175 - val_loss: 1.7337 - val_masked_accuracy: 0.6507 Epoch 9/50 357/357 [==============================] - 112s 313ms/step - loss: 1.6527 - masked_accuracy: 0.6399 - val_loss: 1.6048 - val_masked_accuracy: 0.6655 Epoch 10/50 357/357 [==============================] - 112s 314ms/step - loss: 1.5165 - masked_accuracy: 0.6564 - val_loss: 1.5380 - val_masked_accuracy: 0.6743 Epoch 11/50 357/357 [==============================] - 112s 313ms/step - loss: 1.4235 - masked_accuracy: 0.6691 - val_loss: 1.4807 - val_masked_accuracy: 0.6798 Epoch 12/50 357/357 [==============================] - 112s 315ms/step - loss: 1.3379 - masked_accuracy: 0.6815 - val_loss: 1.4253 - val_masked_accuracy: 0.6912 Epoch 13/50 357/357 [==============================] - 112s 314ms/step - loss: 1.2310 - masked_accuracy: 0.6982 - val_loss: 1.3956 - val_masked_accuracy: 0.6989 Epoch 14/50 357/357 [==============================] - 112s 314ms/step - loss: 1.1418 - masked_accuracy: 0.7134 - val_loss: 1.3350 - val_masked_accuracy: 0.7048 Epoch 15/50 357/357 [==============================] - 112s 314ms/step - loss: 1.0642 - masked_accuracy: 0.7262 - val_loss: 1.2882 - val_masked_accuracy: 0.7184 Epoch 16/50 357/357 [==============================] - 112s 315ms/step - loss: 0.9991 - masked_accuracy: 0.7383 - val_loss: 1.2809 - val_masked_accuracy: 0.7175 Epoch 17/50 357/357 [==============================] - 112s 313ms/step - loss: 0.9416 - masked_accuracy: 0.7494 - val_loss: 1.2556 - val_masked_accuracy: 0.7252 Epoch 18/50 357/357 [==============================] - 112s 314ms/step - loss: 0.8897 - masked_accuracy: 0.7596 - val_loss: 1.2491 - val_masked_accuracy: 0.7253 Epoch 19/50 357/357 [==============================] - 111s 312ms/step - loss: 0.8467 - masked_accuracy: 0.7682 - val_loss: 1.2536 - val_masked_accuracy: 0.7272 Epoch 20/50 357/357 [==============================] - 112s 313ms/step - loss: 0.8083 - masked_accuracy: 0.7764 - val_loss: 1.2510 - val_masked_accuracy: 0.7281 Epoch 21/50 357/357 [==============================] - 112s 313ms/step - loss: 0.7719 - masked_accuracy: 0.7842 - val_loss: 1.2376 - val_masked_accuracy: 0.7318 Epoch 22/50 357/357 [==============================] - 113s 316ms/step - loss: 0.7393 - masked_accuracy: 0.7908 - val_loss: 1.2476 - val_masked_accuracy: 0.7366 Epoch 23/50 357/357 
[==============================] - 112s 313ms/step - loss: 0.7090 - masked_accuracy: 0.7974 - val_loss: 1.2371 - val_masked_accuracy: 0.7368 Epoch 24/50 357/357 [==============================] - 112s 313ms/step - loss: 0.6834 - masked_accuracy: 0.8034 - val_loss: 1.2363 - val_masked_accuracy: 0.7394 Epoch 25/50 357/357 [==============================] - 111s 312ms/step - loss: 0.6574 - masked_accuracy: 0.8096 - val_loss: 1.2489 - val_masked_accuracy: 0.7362 Epoch 26/50 357/357 [==============================] - 112s 313ms/step - loss: 0.6345 - masked_accuracy: 0.8148 - val_loss: 1.2341 - val_masked_accuracy: 0.7429 Epoch 27/50 357/357 [==============================] - 112s 313ms/step - loss: 0.6137 - masked_accuracy: 0.8202 - val_loss: 1.2481 - val_masked_accuracy: 0.7405 Epoch 28/50 357/357 [==============================] - 112s 313ms/step - loss: 0.5922 - masked_accuracy: 0.8249 - val_loss: 1.2349 - val_masked_accuracy: 0.7438 Epoch 29/50 357/357 [==============================] - 112s 314ms/step - loss: 0.5745 - masked_accuracy: 0.8295 - val_loss: 1.2563 - val_masked_accuracy: 0.7436 Epoch 30/50 357/357 [==============================] - 112s 314ms/step - loss: 0.5568 - masked_accuracy: 0.8340 - val_loss: 1.2562 - val_masked_accuracy: 0.7413 Epoch 31/50 357/357 [==============================] - 112s 314ms/step - loss: 0.5404 - masked_accuracy: 0.8378 - val_loss: 1.2522 - val_masked_accuracy: 0.7441 Epoch 32/50 357/357 [==============================] - 112s 314ms/step - loss: 0.5238 - masked_accuracy: 0.8423 - val_loss: 1.2784 - val_masked_accuracy: 0.7446 Epoch 33/50 357/357 [==============================] - 113s 315ms/step - loss: 0.5096 - masked_accuracy: 0.8456 - val_loss: 1.2667 - val_masked_accuracy: 0.7470 Epoch 34/50 357/357 [==============================] - 112s 312ms/step - loss: 0.4976 - masked_accuracy: 0.8487 - val_loss: 1.2776 - val_masked_accuracy: 0.7446 Epoch 35/50 357/357 [==============================] - 111s 312ms/step - loss: 0.4852 - masked_accuracy: 0.8520 - val_loss: 1.2808 - val_masked_accuracy: 0.7465 Epoch 36/50 357/357 [==============================] - 112s 313ms/step - loss: 0.4711 - masked_accuracy: 0.8558 - val_loss: 1.2838 - val_masked_accuracy: 0.7484 Epoch 37/50 357/357 [==============================] - 111s 312ms/step - loss: 0.4605 - masked_accuracy: 0.8591 - val_loss: 1.2955 - val_masked_accuracy: 0.7491 Epoch 38/50 357/357 [==============================] - 111s 312ms/step - loss: 0.4482 - masked_accuracy: 0.8627 - val_loss: 1.3007 - val_masked_accuracy: 0.7466 Epoch 39/50 357/357 [==============================] - 112s 314ms/step - loss: 0.4377 - masked_accuracy: 0.8650 - val_loss: 1.3004 - val_masked_accuracy: 0.7492 Epoch 40/50 357/357 [==============================] - 112s 313ms/step - loss: 0.4298 - masked_accuracy: 0.8672 - val_loss: 1.3008 - val_masked_accuracy: 0.7489 Epoch 41/50 357/357 [==============================] - 112s 313ms/step - loss: 0.4200 - masked_accuracy: 0.8699 - val_loss: 1.3129 - val_masked_accuracy: 0.7518 Epoch 42/50 357/357 [==============================] - 111s 311ms/step - loss: 0.4119 - masked_accuracy: 0.8723 - val_loss: 1.3206 - val_masked_accuracy: 0.7511 Epoch 43/50 357/357 [==============================] - 112s 313ms/step - loss: 0.4014 - masked_accuracy: 0.8753 - val_loss: 1.3289 - val_masked_accuracy: 0.7501 Epoch 44/50 357/357 [==============================] - 112s 314ms/step - loss: 0.3938 - masked_accuracy: 0.8773 - val_loss: 1.3303 - val_masked_accuracy: 0.7502 Epoch 45/50 357/357 
[==============================] - 112s 313ms/step - loss: 0.3850 - masked_accuracy: 0.8797 - val_loss: 1.3414 - val_masked_accuracy: 0.7502 Epoch 46/50 357/357 [==============================] - 112s 313ms/step - loss: 0.3779 - masked_accuracy: 0.8817 - val_loss: 1.3453 - val_masked_accuracy: 0.7521 Epoch 47/50 357/357 [==============================] - 111s 312ms/step - loss: 0.3707 - masked_accuracy: 0.8838 - val_loss: 1.3548 - val_masked_accuracy: 0.7523 Epoch 48/50 357/357 [==============================] - 112s 312ms/step - loss: 0.3625 - masked_accuracy: 0.8861 - val_loss: 1.3469 - val_masked_accuracy: 0.7495 Epoch 49/50 357/357 [==============================] - 111s 312ms/step - loss: 0.3562 - masked_accuracy: 0.8884 - val_loss: 1.3695 - val_masked_accuracy: 0.7518 Epoch 50/50 357/357 [==============================] - 112s 313ms/step - loss: 0.3505 - masked_accuracy: 0.8897 - val_loss: 1.3659 - val_masked_accuracy: 0.7499
<keras.callbacks.History at 0x7f122acaadd0>
transformer.summary()
Model: "transformer_1" _________________________________________________________________ Layer (type) Output Shape Param # ================================================================= encoder_1 (Encoder) multiple 3918848 decoder (Decoder) multiple 6029824 dense_16 (Dense) multiple 1290000 ================================================================= Total params: 11,238,672 Trainable params: 11,238,672 Non-trainable params: 0 _________________________________________________________________
### This code was used to save the trained model weights to Google Drive. It can be ignored.
# from google.colab import drive
# drive.mount('/content/drive')
# transformer.save_weights('drive/MyDrive/saved_model_weights_8_128_02/my_model_weights')
Translator module¶
This module wraps the computation of the transformer. For each shuffled sentence in a batch it builds a bag of the available words and then, autoregressively, picks at every step the word from that bag with the highest score according to the transformer.
class Translator(tf.Module):
def __init__(self, transformer, tokenizer):
self.transformer = transformer
self.tokenizer = tokenizer
def __call__(self, sentences, max_length=max_sequence_len):
batch_size = sentences.shape[0]
# generate word list for each sentence
bow = [[word for word in sentence if word not in [sos, eos, 0]] for sentence in sentences]
# starting vector for prediction, it contains the sos index
output = [[self.tokenizer.word_index['<start>']] for _ in range(batch_size)]
# during inference, output will be filled with the final sentence.
for i in range(1, max_length):
# (enc_input, dec_input)
predictions = np.array(self.transformer((np.array(sentences), np.array(output))))
# remove useless dimensions
predictions = predictions[:, -1, :]
for j in range(batch_size):
if len(bow[j]) == 0:
# no more words to use
cand_token = eos
else:
# choose index with highest score
s_prediction = predictions[j, np.array(bow[j])]
cand_index = np.argmax(s_prediction)
cand_token = bow[j][cand_index]
del bow[j][cand_index]
output[j].append(cand_token)
return output
translator = Translator(transformer, tokenizer)
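As an illustrative usage sketch (not part of the original run), the translator can be applied to a single shuffled test sentence:
# Hypothetical single-sentence example of the translator in action
sample = c_test[:1]                                   # one padded, shuffled test sentence
reconstructed = translator(sample)
print("shuffled:     ", clean_sentence(tokenizer.sequences_to_texts(sample.tolist())[0]))
print("reconstructed:", clean_sentence(tokenizer.sequences_to_texts(reconstructed)[0]))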
Computing the score¶
Now we effectively test our translator and compute the score.
To do that, we compute the score on 3,000 generated samples.
Since computing the score on all 3,000 samples at once could cause problems in Colab, it is computed on batches of 100 samples each.
The total score is then computed as the average over all samples.
score_batch_size = 100
total_test_size = 3000
score_ = 0
for i in range(total_test_size//score_batch_size):
ordered = x_test[i*score_batch_size:(i+1)*score_batch_size]
shuffled = c_test[i*score_batch_size:(i+1)*score_batch_size]
y_pred = translator(shuffled)
b_score = 0 # score associated with each batch
pred_sentences = tokenizer.sequences_to_texts(y_pred)
original_sentences = tokenizer.sequences_to_texts(ordered)
for j in range(score_batch_size) :
b_score += score(clean_sentence(original_sentences[j]), clean_sentence(pred_sentences[j]))
score_ += b_score
print("\n====BATCH OVER====")
print("Score as of batch ", i, ": ", score_/((i+1)*score_batch_size))
score_ = score_/total_test_size
print("\n====ALL OVER====")
print("Final score: ", score_)
====BATCH OVER===== Score as of batch 0 : 0.4948857841637906 ====BATCH OVER===== Score as of batch 1 : 0.5061785304182481 ====BATCH OVER===== Score as of batch 2 : 0.5058148684303929 ====BATCH OVER===== Score as of batch 3 : 0.5027345913851916 ====BATCH OVER===== Score as of batch 4 : 0.517508295140673 ====BATCH OVER===== Score as of batch 5 : 0.530376627646052 ====BATCH OVER===== Score as of batch 6 : 0.5344367026985729 ====BATCH OVER===== Score as of batch 7 : 0.5379621676018749 ====BATCH OVER===== Score as of batch 8 : 0.542301335438453 ====BATCH OVER===== Score as of batch 9 : 0.54320980641855 ====BATCH OVER===== Score as of batch 10 : 0.5395313403534979 ====BATCH OVER===== Score as of batch 11 : 0.5377743859102094 ====BATCH OVER===== Score as of batch 12 : 0.5375765857504796 ====BATCH OVER===== Score as of batch 13 : 0.5374369683136379 ====BATCH OVER===== Score as of batch 14 : 0.5370951441912087 ====BATCH OVER===== Score as of batch 15 : 0.5368290867132967 ====BATCH OVER===== Score as of batch 16 : 0.5333478463531035 ====BATCH OVER===== Score as of batch 17 : 0.5325379103235018 ====BATCH OVER===== Score as of batch 18 : 0.5322204622558182 ====BATCH OVER===== Score as of batch 19 : 0.5328991447097614 ====BATCH OVER===== Score as of batch 20 : 0.5308620255295186 ====BATCH OVER===== Score as of batch 21 : 0.5313577698465886 ====BATCH OVER===== Score as of batch 22 : 0.5300986523987387 ====BATCH OVER===== Score as of batch 23 : 0.5283960723109501 ====BATCH OVER===== Score as of batch 24 : 0.5272936979907543 ====BATCH OVER===== Score as of batch 25 : 0.5266510407918356 ====BATCH OVER===== Score as of batch 26 : 0.5270188445976258 ====BATCH OVER===== Score as of batch 27 : 0.5259018528396617 ====BATCH OVER===== Score as of batch 28 : 0.5279189246585928 ====BATCH OVER===== Score as of batch 29 : 0.5297519821991407 ====ALL OVER===== Final score: 0.5297519821991407
Conclusion¶
The model obtains average performance. Parameter tuning, such as:
- increasing the number of attention heads
- increasing the dropout rate
- increasing the model size
led to similar or lower scores.
Previous attempts¶
In a previous iteration, I tried using a stack of LSTM layers in an encoder/decoder structure, with a context vector passing information between the two. The model also made use of teacher forcing.
The results provided by this architecture were unsatisfactory, with a well below average score, presumably because the model failed to capture the underlying relationships between the sequences during training.
This led to the adoption of the Transformer model.