Variational Autoencoders (VAEs)

The VAE implemented here uses the setup found in most VAE papers: a multivariate Normal distribution for the conditional distribution of the latent vectors given and input image ($q_{\phi}(z | x_i)$ in the slides) and a multivariate Bernoulli distribution for the conditional distribution of images given the latent vector ($p_{\theta}(x | z)$ in the slides). Using a Bernoulli distribution, the reconstruction loss (negative log likelihood of a data point in the output distribution) reduces to the pixel-wise binary cross-entropy. See the original VAE paper, Appendix C.1 for details.

In [ ]:
%%bash
pip install --upgrade pytorch-lightning
pip install tokenizers
Collecting pytorch-lightning
  Downloading pytorch_lightning-1.5.5-py3-none-any.whl (525 kB)
Collecting future>=0.17.1
  Downloading future-0.18.2.tar.gz (829 kB)
Collecting pyDeprecate==0.3.1
  Downloading pyDeprecate-0.3.1-py3-none-any.whl (10 kB)
Requirement already satisfied: numpy>=1.17.2 in /usr/local/lib/python3.7/dist-packages (from pytorch-lightning) (1.19.5)
Collecting torchmetrics>=0.4.1
  Downloading torchmetrics-0.6.1-py3-none-any.whl (332 kB)
Requirement already satisfied: typing-extensions in /usr/local/lib/python3.7/dist-packages (from pytorch-lightning) (3.10.0.2)
Requirement already satisfied: torch>=1.7.* in /usr/local/lib/python3.7/dist-packages (from pytorch-lightning) (1.10.0+cu111)
Collecting fsspec[http]!=2021.06.0,>=2021.05.0
  Downloading fsspec-2021.11.1-py3-none-any.whl (132 kB)
Collecting PyYAML>=5.1
  Downloading PyYAML-6.0-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl (596 kB)
Requirement already satisfied: packaging>=17.0 in /usr/local/lib/python3.7/dist-packages (from pytorch-lightning) (21.3)
Requirement already satisfied: tensorboard>=2.2.0 in /usr/local/lib/python3.7/dist-packages (from pytorch-lightning) (2.7.0)
Requirement already satisfied: tqdm>=4.41.0 in /usr/local/lib/python3.7/dist-packages (from pytorch-lightning) (4.62.3)
Collecting aiohttp
  Downloading aiohttp-3.8.1-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl (1.1 MB)
Requirement already satisfied: requests in /usr/local/lib/python3.7/dist-packages (from fsspec[http]!=2021.06.0,>=2021.05.0->pytorch-lightning) (2.23.0)
Requirement already satisfied: pyparsing!=3.0.5,>=2.0.2 in /usr/local/lib/python3.7/dist-packages (from packaging>=17.0->pytorch-lightning) (3.0.6)
Requirement already satisfied: tensorboard-data-server<0.7.0,>=0.6.0 in /usr/local/lib/python3.7/dist-packages (from tensorboard>=2.2.0->pytorch-lightning) (0.6.1)
Requirement already satisfied: grpcio>=1.24.3 in /usr/local/lib/python3.7/dist-packages (from tensorboard>=2.2.0->pytorch-lightning) (1.42.0)
Requirement already satisfied: google-auth<3,>=1.6.3 in /usr/local/lib/python3.7/dist-packages (from tensorboard>=2.2.0->pytorch-lightning) (1.35.0)
Requirement already satisfied: protobuf>=3.6.0 in /usr/local/lib/python3.7/dist-packages (from tensorboard>=2.2.0->pytorch-lightning) (3.17.3)
Requirement already satisfied: wheel>=0.26 in /usr/local/lib/python3.7/dist-packages (from tensorboard>=2.2.0->pytorch-lightning) (0.37.0)
Requirement already satisfied: google-auth-oauthlib<0.5,>=0.4.1 in /usr/local/lib/python3.7/dist-packages (from tensorboard>=2.2.0->pytorch-lightning) (0.4.6)
Requirement already satisfied: absl-py>=0.4 in /usr/local/lib/python3.7/dist-packages (from tensorboard>=2.2.0->pytorch-lightning) (0.12.0)
Requirement already satisfied: setuptools>=41.0.0 in /usr/local/lib/python3.7/dist-packages (from tensorboard>=2.2.0->pytorch-lightning) (57.4.0)
Requirement already satisfied: tensorboard-plugin-wit>=1.6.0 in /usr/local/lib/python3.7/dist-packages (from tensorboard>=2.2.0->pytorch-lightning) (1.8.0)
Requirement already satisfied: werkzeug>=0.11.15 in /usr/local/lib/python3.7/dist-packages (from tensorboard>=2.2.0->pytorch-lightning) (1.0.1)
Requirement already satisfied: markdown>=2.6.8 in /usr/local/lib/python3.7/dist-packages (from tensorboard>=2.2.0->pytorch-lightning) (3.3.6)
Requirement already satisfied: six in /usr/local/lib/python3.7/dist-packages (from absl-py>=0.4->tensorboard>=2.2.0->pytorch-lightning) (1.15.0)
Requirement already satisfied: pyasn1-modules>=0.2.1 in /usr/local/lib/python3.7/dist-packages (from google-auth<3,>=1.6.3->tensorboard>=2.2.0->pytorch-lightning) (0.2.8)
Requirement already satisfied: rsa<5,>=3.1.4 in /usr/local/lib/python3.7/dist-packages (from google-auth<3,>=1.6.3->tensorboard>=2.2.0->pytorch-lightning) (4.8)
Requirement already satisfied: cachetools<5.0,>=2.0.0 in /usr/local/lib/python3.7/dist-packages (from google-auth<3,>=1.6.3->tensorboard>=2.2.0->pytorch-lightning) (4.2.4)
Requirement already satisfied: requests-oauthlib>=0.7.0 in /usr/local/lib/python3.7/dist-packages (from google-auth-oauthlib<0.5,>=0.4.1->tensorboard>=2.2.0->pytorch-lightning) (1.3.0)
Requirement already satisfied: importlib-metadata>=4.4 in /usr/local/lib/python3.7/dist-packages (from markdown>=2.6.8->tensorboard>=2.2.0->pytorch-lightning) (4.8.2)
Requirement already satisfied: zipp>=0.5 in /usr/local/lib/python3.7/dist-packages (from importlib-metadata>=4.4->markdown>=2.6.8->tensorboard>=2.2.0->pytorch-lightning) (3.6.0)
Requirement already satisfied: pyasn1<0.5.0,>=0.4.6 in /usr/local/lib/python3.7/dist-packages (from pyasn1-modules>=0.2.1->google-auth<3,>=1.6.3->tensorboard>=2.2.0->pytorch-lightning) (0.4.8)
Requirement already satisfied: urllib3!=1.25.0,!=1.25.1,<1.26,>=1.21.1 in /usr/local/lib/python3.7/dist-packages (from requests->fsspec[http]!=2021.06.0,>=2021.05.0->pytorch-lightning) (1.24.3)
Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.7/dist-packages (from requests->fsspec[http]!=2021.06.0,>=2021.05.0->pytorch-lightning) (2021.10.8)
Requirement already satisfied: idna<3,>=2.5 in /usr/local/lib/python3.7/dist-packages (from requests->fsspec[http]!=2021.06.0,>=2021.05.0->pytorch-lightning) (2.10)
Requirement already satisfied: chardet<4,>=3.0.2 in /usr/local/lib/python3.7/dist-packages (from requests->fsspec[http]!=2021.06.0,>=2021.05.0->pytorch-lightning) (3.0.4)
Requirement already satisfied: oauthlib>=3.0.0 in /usr/local/lib/python3.7/dist-packages (from requests-oauthlib>=0.7.0->google-auth-oauthlib<0.5,>=0.4.1->tensorboard>=2.2.0->pytorch-lightning) (3.1.1)
Collecting multidict<7.0,>=4.5
  Downloading multidict-5.2.0-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl (160 kB)
Collecting frozenlist>=1.1.1
  Downloading frozenlist-1.2.0-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl (192 kB)
Collecting async-timeout<5.0,>=4.0.0a3
  Downloading async_timeout-4.0.1-py3-none-any.whl (5.7 kB)
Collecting asynctest==0.13.0
  Downloading asynctest-0.13.0-py3-none-any.whl (26 kB)
Requirement already satisfied: attrs>=17.3.0 in /usr/local/lib/python3.7/dist-packages (from aiohttp->fsspec[http]!=2021.06.0,>=2021.05.0->pytorch-lightning) (21.2.0)
Requirement already satisfied: charset-normalizer<3.0,>=2.0 in /usr/local/lib/python3.7/dist-packages (from aiohttp->fsspec[http]!=2021.06.0,>=2021.05.0->pytorch-lightning) (2.0.8)
Collecting yarl<2.0,>=1.0
  Downloading yarl-1.7.2-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl (271 kB)
Collecting aiosignal>=1.1.2
  Downloading aiosignal-1.2.0-py3-none-any.whl (8.2 kB)
Building wheels for collected packages: future
  Building wheel for future (setup.py): started
  Building wheel for future (setup.py): finished with status 'done'
  Created wheel for future: filename=future-0.18.2-py3-none-any.whl size=491070 sha256=12d80ed5409b3624f800dcf85f48261389359729cb6171c70887c6a6d9eda331
  Stored in directory: /root/.cache/pip/wheels/56/b0/fe/4410d17b32f1f0c3cf54cdfb2bc04d7b4b8f4ae377e2229ba0
Successfully built future
Installing collected packages: multidict, frozenlist, yarl, asynctest, async-timeout, aiosignal, fsspec, aiohttp, torchmetrics, PyYAML, pyDeprecate, future, pytorch-lightning
  Attempting uninstall: PyYAML
    Found existing installation: PyYAML 3.13
    Uninstalling PyYAML-3.13:
      Successfully uninstalled PyYAML-3.13
  Attempting uninstall: future
    Found existing installation: future 0.16.0
    Uninstalling future-0.16.0:
      Successfully uninstalled future-0.16.0
Successfully installed PyYAML-6.0 aiohttp-3.8.1 aiosignal-1.2.0 async-timeout-4.0.1 asynctest-0.13.0 frozenlist-1.2.0 fsspec-2021.11.1 future-0.18.2 multidict-5.2.0 pyDeprecate-0.3.1 pytorch-lightning-1.5.5 torchmetrics-0.6.1 yarl-1.7.2
Collecting tokenizers
  Downloading tokenizers-0.10.3-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl (3.3 MB)
Installing collected packages: tokenizers
Successfully installed tokenizers-0.10.3
In [ ]:
from google.colab import drive 
drive.mount('/content/drive')

%cd /content/drive/My\ Drive/6.864-final-project
Mounted at /content/drive
/content/drive/My Drive/6.864-final-project
In [ ]:
import os
import sys
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, random_split
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
from torch.autograd import Variable
from nltk.tokenize import TweetTokenizer

import pytorch_lightning as pl
from pytorch_lightning import Trainer, seed_everything

from tokenizers import ByteLevelBPETokenizer, normalizers
from tokenizers.normalizers import NFKC, StripAccents, Lowercase, NFKD

import pandas as pd
import re
%matplotlib inline

Data Loading

Data Processing

In [ ]:
df.head()
Out[ ]:
0
0 @WeLoveYouZahir bro 😭
1 FREE RAPHAEL
2 they done arrested a ninja turtle https://t.co...
3 @picgoeshard post this https://t.co/1FiX5SvA6v
4 successfuIIy wasted 11 months of 2021
In [ ]:
from ast import literal_eval
import re
def remove_emoji(string):
    emoji_pattern = re.compile("["
                           u"\U0001F600-\U0001F64F" # emoticons
                           u"\U0001F300-\U0001F5FF" # symbols & pictographs
                           u"\U0001F680-\U0001F6FF" # transport & map symbols
                           u"\U0001F1E0-\U0001F1FF" # flags (iOS)
                           u"\U00002702-\U000027B0"
                           u"\U000024C2-\U0001F251"
                           "]+", flags=re.UNICODE)
    return emoji_pattern.sub(r'', string)

df = pd.read_csv("./100ktweets.csv", encoding='utf-8',header=None)
df.rename(columns = {0:'text',}, 
            inplace = True)
df["text"] = df["text"].astype(str)
#df["text"] = df["text"].apply(remove_emoji)
df["text"] = df["text"].apply(lambda x : re.sub(r'http\S+', '', x)) #remove links
df["text"] = df["text"].apply(lambda x : re.sub(r'><', '> <', x)) #separate emojis
df = df.reset_index()
In [ ]:
class TweetsDataset(Dataset):
    def __init__(self, df):
        df["text"] = df["text"].apply(lambda x : re.sub(r'http\S+', '', x))
        df["text"] = df["text"].astype(str)
        self.frame = df

        self.tokenizer = ByteLevelBPETokenizer(end_of_word_suffix="</w>")
        self.tokenizer.normalizer = normalizers.Sequence([NFKD(), StripAccents(), Lowercase()])
        self.tokenizer.train_from_iterator(iter(self.frame["text"]), min_frequency=2, special_tokens=[
            "<s>",
            "<pad>",
            "</s>",
            "<unk>",
            "</w>",
        ])
        self.maxlen = 35
        for ix, row in df.iterrows():
          enc = self.tokenizer.encode(row["text"])
          curlen = len(enc.ids)
          if curlen > self.maxlen:
            self.frame.drop(ix, inplace=True)
        self.tokenizer.enable_padding(pad_id=1, pad_token="<pad>", length=self.maxlen+1)
        self.tokenizer.save_model(".", "twittertok")
        self.frame = self.frame.reset_index()

    def __len__(self):
        return len(self.frame)

    def __getitem__(self, idx):
        raw_tweet = self.frame.at[idx, "text"]
        encoded_obj = self.tokenizer.encode(raw_tweet, is_pretokenized=False)
        ids = encoded_obj.ids
        first_pad_ix = ids.index(1)
        ids = [0] + ids[:first_pad_ix] + [2] + ids[first_pad_ix:]
        unpadded_toks = [x for x in ids if x != 1]
        sample = {'ids': torch.tensor(ids).long(), 
                  'len':len(unpadded_toks)}
        return sample
In [ ]:
td = TweetsDataset(df)
n_train, n_val, n_test = int(round(len(td)*.75)), int(round(len(td)*.15)), int(round(len(td)*.10))
train_dataset, val_dataset, test_dataset = random_split(td, [n_train, n_val, n_test])

Model Architecture + Parameters

Encoder

In [ ]:
class Encoder(nn.Module):
    def __init__(self, vocab_size=256, 
                       embed_size=256, 
                       hidden_size=256, 
                       nhead=8, 
                       transformer_layers=8,
                       rnn_layers = 3,
                       dropout=0.15,
                       latent_dims=16):
        super(Encoder, self).__init__()
        
        encoder_layer = nn.TransformerEncoderLayer(d_model=embed_size, nhead=nhead, batch_first=True)
        self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=6) 
        self.rnn_layers = rnn_layers
        self.hidden_size = hidden_size
        self.rnn = nn.GRU(input_size=embed_size, 
                          hidden_size=hidden_size, 
                          num_layers=rnn_layers, 
                          dropout=dropout, 
                          bidirectional=True, 
                          batch_first=True)
        self.mu = nn.Linear(in_features=2*hidden_size, 
                            out_features=latent_dims)
        self.logvar = nn.Linear(in_features=2*hidden_size, 
                                out_features=latent_dims)
            
    def forward(self, x, lengths):
        batch_size = x.shape[0]
        x = self.transformer_encoder(x) #keeps dimensionality

        packed = pack_padded_sequence(x, lengths.cpu(), batch_first=True,
                                  enforce_sorted=False)
        outputs, hidden = self.rnn(packed) # output is (batch_size x sent_len x 2*hidden_size), hidden is (2*n_layers x batch_size x hidden_size)
        #no need to unpack, only using hiddens

        last_hidden = hidden.view(self.rnn_layers, 2, batch_size, self.hidden_size) # 2 for bidirectional... now (n_layers x 2 x batch_size x hidden_size)
        last_hidden_f = torch.squeeze(last_hidden[-1, 0, :, :]) # (batch_size x hidden_size)
        last_hidden_b = torch.squeeze(last_hidden[-1, 1, :, :]) # (batch_size x hidden_size)
        hid = torch.cat([last_hidden_f, last_hidden_b], dim=-1) # (batch_size x 2*hidden_size)

        x_mu = self.mu(hid) # (batch_size x latent_dims)
        x_logvar = self.logvar(hid) # (batch_size x latent_dims)
        return x_mu, x_logvar

Decoder

In [ ]:
class Decoder(nn.Module):
  def __init__(self, embedder,
                     embed_size=128,
                     hidden_size=256,
                     rnn_layers=3,
                     dropout=0.15,
                     vocab_size=1000
                     ):
    super(Decoder, self).__init__()
    
    self.rnn = nn.GRU(input_size=embed_size, 
                      hidden_size=hidden_size, 
                      num_layers=rnn_layers, 
                      batch_first=True, 
                      dropout=dropout, 
                      bidirectional=False)
    
    self.embedding = embedder

    #linear going from hidden size to the vocab size
    self.hidden_to_vocab = nn.Linear(hidden_size, vocab_size)
  
  def nucleus_sampling(self, outputs, top_p=0.5):
    sm = nn.Softmax(dim=-1)
    #initially receive (batch_size, 1, vocab_size) probs 
    logits = torch.squeeze(outputs, dim=1)
    sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1) #batch_size, vocab_size
    sorted_cumulative_probs = torch.cumsum(sm(sorted_logits), -1) #diff from sorted_probs
    
    nucleus = torch.where(sorted_cumulative_probs < top_p, sorted_logits, torch.tensor(0, dtype=sorted_logits.dtype))
    unsorted_nucleus = nucleus.gather(1, sorted_indices.argsort(1))
    probs = sm(unsorted_nucleus) #batch_size, vocab_size
    selected_tokens = torch.multinomial(probs, 1).view(outputs.shape[0], 1) #(batch_size,1)

    return selected_tokens

  def forward(self, init_hiddens, ground_truth=None):
    """Unroll the decoder one step at a time.

    Inputs:
      - `init_hiddens`: a 3d-tensor of shape
          (n_layers, batch_size, hidden_size) representing the final
          encoder hidden states used to initialize the decoder hidden
          states.
      - `ground_truth`: a 3d-tensor of shape (batch_size, max_seq_length, embed_size)
          representing a batch of padded word vectors of target sentences [ONLY IF TEACHER FORCING]

    Returns:
      - `pre_output_vectors`: a 3d-tensor of shape
          (batch_size, max_len, hidden_size) representing the raw decoder
          outputs (before mapping to a `trg_vocab_size`-dim vector).
    """
    num_enc_layers, batch_size, hidden_size = init_hiddens.shape
    
    output_vectors = []
    sequence = []
    hidden = init_hiddens
    cur_input = torch.zeros((batch_size,1)).long()
    for i in range(td.maxlen+3): #one for the extra pad, one for sos token, one for eos token
      if ground_truth is not None: #teacher forcing
        cur_input = torch.unsqueeze(ground_truth[:,i,:], 1) #(batch_size, 1, embed_size)

      else: #this is at test time... we embed the input we have 
        cur_input = self.embedding(cur_input) # (batch_size, 1, embed_size)

      pre_output, hidden = self.rnn(cur_input, hidden)  #pre-output is (batch_size, 1, hidden_size)
      output = self.hidden_to_vocab(pre_output) #output is (batch_size, 1, vocab_size)

      if ground_truth is None: #if we dont have ground truth (testing/validation), then select word w/ nucleus sampling
        cur_input = self.nucleus_sampling(output) # (batch_size,1)
      
      output_vectors.append(output) 
      sequence.append(cur_input)

    output_vectors = torch.cat(output_vectors, dim=1)
    token_sequence = torch.cat(sequence, dim=-1)
    return output_vectors, token_sequence

VAE Loss

In [ ]:
def vae_loss(recon_x, x, mu, logvar, variational_beta, vocab_size):
    batch_size = x.shape[0]
    # recon_x is the probability of a multivariate Bernoulli distribution p.
    # Averaging or not averaging the binary cross-entropy over all pixels here
    # is a subtle detail with big effect on training, since it changes the weight
    # we need to pick for the other loss term by several orders of magnitude.
    # Not averaging is the direct implementation of the negative log likelihood,
    # but averaging makes the weight of the other loss term independent of the image resolution.
    loss = nn.CrossEntropyLoss(reduction='sum', ignore_index=1)
    logit = recon_x.view(-1, vocab_size) 
    view_x = x.contiguous().view(-1)
    recon_loss = loss(logit, view_x)
    
    # KL-divergence between the prior distribution over latent vectors
    # (the one we are going to sample from when generating new images)
    # and the distribution estimated by the generator for the given image.
    KLD_element = mu.pow(2).add_(logvar.exp()).mul_(-1).add_(1).add_(logvar)
    KLD = torch.sum(KLD_element).mul_(-0.5)

    return (recon_loss + variational_beta*KLD)/batch_size

Combined VAE

In [ ]:
class VariationalAutoEncoder(pl.LightningModule):
    def __init__(self, vocab_size=256, 
                       embed_size=256, 
                       hidden_size=256, 
                       nhead=8, 
                       transformer_layers=8,
                       rnn_layers = 3,
                       dropout=0.15,
                       latent_dims=16,
                       variational_beta=10,
                       batch_size=16,
                       lr=2e-4):
        super().__init__()
        self.save_hyperparameters()

        #combination of Cross-entropy loss for reconstruction and KLDivLoss for variational stability 
        self.loss = vae_loss
        self.variational_beta = variational_beta

        #some layers to find good embeddings and turn the latent reepresentation into teh right dimensions
        self.embedding = nn.Embedding(vocab_size, embed_size, padding_idx=1)
        self.latent_to_hidden = nn.Linear(latent_dims, hidden_size)   

        #encoder: for turning the input into a mu and a sigma to get a latent representation
        self.encoder = Encoder(vocab_size=vocab_size, 
                               embed_size=embed_size, 
                               hidden_size=hidden_size, 
                               nhead=nhead, 
                               transformer_layers=transformer_layers,
                               rnn_layers = rnn_layers,
                               dropout=dropout,
                               latent_dims=latent_dims)
        
        #decoder: for decoding of latent rpr back into a sequence
        self.decoder = Decoder(embed_size=embed_size,
                               hidden_size=hidden_size,
                               rnn_layers=rnn_layers,
                               dropout=dropout,
                               embedder=self.embedding,
                               vocab_size=vocab_size) 

    def forward(self, n=1, z=None):
        if not z:
          z = torch.randn(n, self.hparams.latent_dims)

        #decoder
        z = self.latent_to_hidden(z) #n_batch x hidden_size
        z = torch.unsqueeze(z, dim=0).repeat(self.hparams.rnn_layers, 1, 1) #num_layers x n_batch x hidden_size
        x_hat, sequence = self.decoder(z) #(batch_size, max_len, hidden_size)

        return sequence

    def training_step(self, batch, batch_idx):
        # training_step defined the train loop. It is independent of forward
        x, lens = batch["ids"], batch["len"] # n_batch x sent_len
        x_embedded = self.embedding(x) # n_batch x sent_len x embed_dims

        #run through encoder
        mu, logvar = self.encoder(x_embedded, lens)

        #reparametrization trick
        z = self.reparametrize(mu, logvar)

        #decoder
        z = self.latent_to_hidden(z) #n_batch x hidden_size
        z = torch.unsqueeze(z, dim=0).repeat(self.hparams.rnn_layers, 1, 1) #num_layers x n_batch x hidden_size
        x_hat, _ = self.decoder(z, ground_truth=x_embedded) #(batch_size, max_len, vocab_size)

        #eval
        loss = self.loss(x_hat, x, mu, logvar, self.variational_beta, self.hparams.vocab_size)

        # Logging to TensorBoard by default
        self.log("train_loss", loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)
        return loss
    
    def validation_step(self, batch, batch_idx):
        # training_step defined the train loop. It is independent of forward
        x, lens = batch["ids"], batch["len"] # n_batch x sent_len
        x_embedded = self.embedding(x) # n_batch x sent_len x embed_dims

        #run through encoder
        mu, logvar = self.encoder(x_embedded, lens)

        #reparametrization trick
        z = self.reparametrize(mu, logvar)

        #decoder
        z = self.latent_to_hidden(z) #n_batch x hidden_size
        z = torch.unsqueeze(z, dim=0).repeat(self.hparams.rnn_layers, 1, 1) #num_layers x n_batch x hidden_size
        x_hat, _ = self.decoder(z, ground_truth=x_embedded) #(batch_size, max_len, vocab_size)

        #eval
        loss = self.loss(x_hat, x, mu, logvar, self.variational_beta, self.hparams.vocab_size)

        # Logging to TensorBoard by default
        self.log("val_loss", loss, on_step=True, on_epoch=True, logger=True)
        return loss
    
    def reparametrize(self, mu, log_var):
        """you generate a random distribution w.r.t. the mu and log_var from the embedding space.
        In order for the back-propagation to work, we need to be able to calculate the gradient. 
        This reparameterization trick first generates a normal distribution, then shapes the distribution
        with the mu and variance from the encoder.
        
        This way, we can can calculate the gradient parameterized by this particular random instance.
        """
        eps = Variable(torch.randn(mu.shape[0], self.hparams.latent_dims).type_as(mu))
        std = log_var.mul(0.5).exp_()   
        return eps.mul(std).add_(mu)

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.hparams.lr)
        return optimizer

Training

In [ ]:
seed_everything(42, workers=True)

trainer = Trainer(check_val_every_n_epoch=1,
                  fast_dev_run=False, #true to "unit test" 
                  max_epochs=50,
                  precision=16,
                  profiler="simple",
                  gpus=1,
                  log_every_n_steps=20,
                  limit_train_batches=1.0,
                  limit_test_batches=1.0,
                  limit_val_batches=1.0,
                  max_time="00:02:30:00")

model = VariationalAutoEncoder(vocab_size=td.tokenizer.get_vocab_size(with_added_tokens=True), 
                               embed_size=128, 
                               hidden_size=256, 
                               nhead=8, 
                               transformer_layers=6,
                               rnn_layers = 3,
                               dropout=0.2,
                               latent_dims=16,
                               variational_beta=10)

train_loader = DataLoader(train_dataset, batch_size=model.hparams.batch_size, num_workers=8, pin_memory=True)
val_loader = DataLoader(val_dataset, batch_size=model.hparams.batch_size, num_workers=8, pin_memory=True)
test_loader = DataLoader(test_dataset, batch_size=model.hparams.batch_size, num_workers=8, pin_memory=True)

trainer.tune(model)
trainer.fit(model, train_dataloaders=train_loader, val_dataloaders=val_loader)
Global seed set to 42
Using 16bit native Automatic Mixed Precision (AMP)
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name             | Type      | Params
-----------------------------------------------
0 | embedding        | Embedding | 3.8 M 
1 | latent_to_hidden | Linear    | 4.4 K 
2 | encoder          | Encoder   | 6.5 M 
3 | decoder          | Decoder   | 12.6 M
-----------------------------------------------
19.2 M    Trainable params
0         Non-trainable params
19.2 M    Total params
38.346    Total estimated model params size (MB)
Global seed set to 42
Time limit reached. Elapsed time is 2:30:00. Signaling Trainer to stop.
FIT Profiler Report

Action                             	|  Mean duration (s)	|Num calls      	|  Total time (s) 	|  Percentage %   	|
--------------------------------------------------------------------------------------------------------------------------------------
Total                              	|  -              	|_              	|  9042.5         	|  100 %          	|
--------------------------------------------------------------------------------------------------------------------------------------
run_training_epoch                 	|  602.48         	|15             	|  9037.2         	|  99.942         	|
run_training_batch                 	|  0.12534        	|66520          	|  8337.9         	|  92.208         	|
optimizer_step_with_closure_0      	|  0.12386        	|66520          	|  8239.3         	|  91.118         	|
training_step_and_backward         	|  0.11461        	|66520          	|  7623.6         	|  84.309         	|
backward                           	|  0.069577       	|66520          	|  4628.3         	|  51.183         	|
model_forward                      	|  0.04393        	|66520          	|  2922.2         	|  32.317         	|
training_step                      	|  0.043748       	|66520          	|  2910.1         	|  32.182         	|
evaluation_step_and_end            	|  0.036913       	|13742          	|  507.26         	|  5.6098         	|
validation_step                    	|  0.036832       	|13742          	|  506.14         	|  5.5974         	|
zero_grad                          	|  0.0010618      	|66520          	|  70.633         	|  0.78112        	|
on_train_batch_end                 	|  0.00042176     	|66520          	|  28.055         	|  0.31026        	|
training_batch_to_device           	|  0.00033424     	|66520          	|  22.233         	|  0.24588        	|
on_train_epoch_end                 	|  1.2338         	|15             	|  18.507         	|  0.20466        	|
get_train_batch                    	|  0.00020163     	|66535          	|  13.416         	|  0.14836        	|
on_train_batch_start               	|  0.00019669     	|66520          	|  13.084         	|  0.14469        	|
fetch_next_train_batch             	|  0.00017846     	|66535          	|  11.874         	|  0.13131        	|
get_validate_batch                 	|  0.00030666     	|13755          	|  4.218          	|  0.046647       	|
evaluation_batch_to_device         	|  0.000288       	|13742          	|  3.9578         	|  0.043769       	|
fetch_next_validate_batch          	|  0.00028377     	|13755          	|  3.9032         	|  0.043165       	|
on_validation_batch_end            	|  0.00018759     	|13742          	|  2.5779         	|  0.028508       	|
on_after_backward                  	|  3.1708e-05     	|66520          	|  2.1092         	|  0.023326       	|
on_batch_start                     	|  2.7847e-05     	|66520          	|  1.8524         	|  0.020485       	|
on_before_optimizer_step           	|  2.37e-05       	|66520          	|  1.5766         	|  0.017435       	|
on_batch_end                       	|  2.3199e-05     	|66520          	|  1.5432         	|  0.017066       	|
on_before_zero_grad                	|  2.2646e-05     	|66520          	|  1.5064         	|  0.016659       	|
on_before_backward                 	|  2.1701e-05     	|66520          	|  1.4435         	|  0.015964       	|
training_step_end                  	|  7.9252e-06     	|66520          	|  0.52718        	|  0.0058301      	|
on_validation_start                	|  0.028537       	|16             	|  0.45659        	|  0.0050494      	|
on_validation_batch_start          	|  2.3208e-05     	|13742          	|  0.31892        	|  0.0035269      	|
validation_step_end                	|  8.1622e-06     	|13742          	|  0.11217        	|  0.0012404      	|
on_train_start                     	|  0.038581       	|1              	|  0.038581       	|  0.00042667     	|
on_train_epoch_start               	|  0.0023525      	|15             	|  0.035287       	|  0.00039023     	|
on_validation_end                  	|  0.002137       	|16             	|  0.034192       	|  0.00037813     	|
on_sanity_check_start              	|  0.027416       	|1              	|  0.027416       	|  0.0003032      	|
get_sanity_check_batch             	|  0.0068383      	|3              	|  0.020515       	|  0.00022687     	|
fetch_next_sanity_check_batch      	|  0.0067918      	|3              	|  0.020375       	|  0.00022533     	|
on_pretrain_routine_start          	|  0.0091581      	|1              	|  0.0091581      	|  0.00010128     	|
on_validation_model_eval           	|  0.00045034     	|16             	|  0.0072054      	|  7.9684e-05     	|
on_train_end                       	|  0.0013276      	|1              	|  0.0013276      	|  1.4681e-05     	|
on_epoch_start                     	|  2.6267e-05     	|31             	|  0.00081426     	|  9.0049e-06     	|
on_epoch_end                       	|  2.2478e-05     	|31             	|  0.00069682     	|  7.7061e-06     	|
configure_optimizers               	|  0.00053451     	|1              	|  0.00053451     	|  5.9111e-06     	|
on_validation_epoch_end            	|  2.9137e-05     	|16             	|  0.00046618     	|  5.1555e-06     	|
on_validation_epoch_start          	|  2.3956e-05     	|16             	|  0.00038329     	|  4.2388e-06     	|
on_sanity_check_end                	|  3.6545e-05     	|1              	|  3.6545e-05     	|  4.0415e-07     	|
on_pretrain_routine_end            	|  3.6033e-05     	|1              	|  3.6033e-05     	|  3.9849e-07     	|
on_fit_end                         	|  3.323e-05      	|1              	|  3.323e-05      	|  3.6749e-07     	|
on_before_accelerator_backend_setup	|  3.2395e-05     	|1              	|  3.2395e-05     	|  3.5825e-07     	|
on_configure_sharded_model         	|  2.7212e-05     	|1              	|  2.7212e-05     	|  3.0094e-07     	|
on_fit_start                       	|  2.5118e-05     	|1              	|  2.5118e-05     	|  2.7778e-07     	|
teardown                           	|  1.8653e-05     	|1              	|  1.8653e-05     	|  2.0628e-07     	|
setup                              	|  1.5147e-05     	|1              	|  1.5147e-05     	|  1.6751e-07     	|
configure_sharded_model            	|  7.582e-06      	|1              	|  7.582e-06      	|  8.3849e-08     	|
on_train_dataloader                	|  7.083e-06      	|1              	|  7.083e-06      	|  7.833e-08      	|
on_val_dataloader                  	|  7.077e-06      	|1              	|  7.077e-06      	|  7.8264e-08     	|
prepare_data                       	|  6.97e-06       	|1              	|  6.97e-06       	|  7.7081e-08     	|
configure_callbacks                	|  3.906e-06      	|1              	|  3.906e-06      	|  4.3196e-08     	|

In [ ]:
td.tokenizer.decode(list(model.forward(n=1)[0]))
Out[ ]:
' tsaxofauxorld xbox dropped manipul😂😂 kneeaaayeboitaereviv🤫🤫🤫🤫ansas bish keysother agrees gotta ✊🏽 consealph followin polo 🙏🏽🙏🏽🙏🏽dicckussoyuming wonderantheress reboot neg saintgatorsafterdarkcrackdai astrow'
In [ ]:
%load_ext tensorboard
%tensorboard --logdir lightning_logs/

Evaluation + Testing