LSTM model trained on wikitext-2 (pytorch)

6 minute read


LSTMs are a widely used architecture in NLP. In this blog I explain the architecture, how I trained the model, and what the results are. The model predicts the most probable next word in a sentence.

LSTM on wikitext-2

Here we will use the wikitext-2-raw-v1 dataset from Hugging Face (Hugging Face is the GitHub of ML). In this preprocessing step we choose the vocabulary size and text_len, the maximum sequence length the model looks back over. For example, when predicting the next word during training, a sequence is at most text_len tokens long; after that we start again from $h_{0}$ (the hidden vector at the start, usually all zeros). We tokenize the words here and store the word-to-index mapping as a dictionary and the index-to-word mapping as a list.

from datasets import load_dataset
from collections import Counter
import torch
import random
vocab = 50000
text_len = 20
# Load the WikiText-2 dataset
dataset = load_dataset(path="wikitext", name='wikitext-2-raw-v1', split="train")
# Define tokenizer
tokenizer = lambda x: x.split()  # Simple tokenizer splitting by space
# Tokenize the text and count frequency
counter = Counter()
train_data=[]
for example in dataset:
    tokens = tokenizer(example["text"])
    if len(tokens)>=text_len:  # keep only lines long enough to form a full sequence
      train_data.extend(tokens)
    counter.update(tokens)

# Keep the (vocab - 1) most common words; the last index is reserved for unknown words
most_common_words = counter.most_common(vocab-1)
word_to_index = {word: i for i, (word, _) in enumerate(most_common_words)}



# Map a word to its vocabulary index; unknown words fall back to the last index (vocab - 1)
def word_to_id(word):
    return word_to_index.get(word, vocab-1)

# Convert the training tokens into index sequences
train_data = [word_to_id(word) for word in train_data]
print(most_common_words[:100])
[('the', 113161), (',', 99913), ('.', 73388), ('of', 56889), ('and', 50603), ('in', 39453), ('to', 39190), ('a', 34237), ('=', 29570), ('"', 28309), ('was', 20985), ('The', 17602), ('@-@', 16906), ('that', 14135), ('as', 14021), ("'s", 14002), ('on', 13678), ('for', 13307), ('with', 12606), ('by', 12148), (')', 12004), ('(', 11992), ('is', 11637), ('from', 8941), ('his', 8395), ('at', 8186), ('were', 7330), ('it', 6748), ('he', 5947), ('an', 5895), ('had', 5698), ('which', 5543), ('In', 5525), ('be', 4787), ('are', 4669), (';', 4224), ('their', 4127), ('but', 4035), ('first', 3978), ('–', 3934), ('not', 3906), (':', 3806), ('also', 3746), ('its', 3697), ('or', 3636), ('her', 3486), ('have', 3442), ('one', 3351), ('has', 3322), ('been', 3257), ('two', 3238), ('@.@', 3194), ('this', 3079), ('who', 2963), ('they', 2941), ("'", 2870), ('He', 2759), ('@,@', 2699), ('after', 2640), ('It', 2522), ('time', 2451), ('into', 2440), ('other', 2433), ('would', 2320), ('more', 2319), ('1', 2253), ('all', 2100), ('when', 2036), ('I', 2026), ('over', 2015), ('2', 1982), ('she', 1981), ('during', 1973), ('only', 1948), ('A', 1919), ('game', 1846), ('up', 1843), ('him', 1840), ('about', 1833), ('three', 1803), ('most', 1778), ('between', 1768), ('than', 1759), ('out', 1707), ('while', 1684), ('later', 1682), ('season', 1641), ('made', 1635), ('3', 1601), ('such', 1532), ('year', 1507), ('used', 1505), ('new', 1501), ('where', 1490), ('them', 1481), ('This', 1481), ('some', 1469), ('On', 1462), ('000', 1447), ('song', 1444)]
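As a quick sanity check (an illustrative snippet using the variables defined above, not part of the original script), we can round-trip a word through the two mappings; unknown words fall back to the reserved last index:

# Round trip: word -> index via the dict, index -> word via the most_common_words list
word = 'season'
idx = word_to_index.get(word, vocab - 1)
print(word, '->', idx, '->', most_common_words[idx][0])

# Words outside the vocabulary map to the reserved unknown index (vocab - 1)
print(word_to_index.get('notaword123', vocab - 1) == vocab - 1)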

Now we define a custom dataset class that inherits from the torch.utils.data.Dataset base class. Each sample is a window of text_len token indices, and its target is the same window shifted by one token. The code is simple enough to understand on its own.

from torch.utils.data import Dataset

class Text_dataset(Dataset):

  def __init__(self, text_list, text_len):
    self.data= text_list
    self.text_len = text_len

  def __len__(self):
    # Number of complete (input, target) windows in the data
    return (len(self.data) - 1) // self.text_len

  # Return a window of text_len tokens and the same window shifted by one token as the target
  def __getitem__(self, i):
    start = i * self.text_len
    return (torch.tensor(self.data[start:start + self.text_len]),
            torch.tensor(self.data[start + 1:start + self.text_len + 1]))
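To see the input/target alignment concretely, here is a tiny illustrative example (a toy token list, not the real training data); the target is simply the input shifted by one token:

toy_tokens = list(range(100))            # pretend token indices
toy_ds = Text_dataset(toy_tokens, text_len=5)
x, y = toy_ds[0]
print(x)            # tensor([0, 1, 2, 3, 4])
print(y)            # tensor([1, 2, 3, 4, 5])
print(len(toy_ds))  # 19 complete windows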

Important notes beyond the plain LSTM layers: dropout is applied to the embedding output and to the LSTM output before the fully connected layer to avoid overfitting.

import torch.nn as nn

class LSTM_Gen(nn.Module):
  def __init__(self, input_size=1000, hidden_size=1024, hidden_layer=1, embedding_size=512, batch_size=20):
    super(LSTM_Gen, self).__init__()
    self.input_size = input_size        # vocabulary size
    self.hidden_size = hidden_size
    self.hidden_layer = hidden_layer    # number of stacked LSTM layers
    self.embedding_size = embedding_size
    self.batch_size = batch_size
    self.dropout = nn.Dropout(p=0.5)
    self.embedding = nn.Embedding(self.input_size, self.embedding_size)
    self.lstm = nn.LSTM(self.embedding_size, self.hidden_size, self.hidden_layer, batch_first=True)
    self.fc1 = nn.Linear(self.hidden_size, self.input_size)

  def forward(self, x, hidden):
    x = self.embedding(x)               # (batch, seq) -> (batch, seq, embedding_size)
    x = self.dropout(x)
    x, h = self.lstm(x, hidden)         # (batch, seq, hidden_size)
    out = self.fc1(self.dropout(x))     # logits over the vocabulary
    return out, h
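Before training, a quick shape check with dummy inputs (a small sketch using toy sizes, not the real hyperparameters) shows how tensors flow through the model:

toy_model = LSTM_Gen(input_size=100, hidden_size=32, hidden_layer=2, embedding_size=16)
x = torch.randint(0, 100, (4, 10))                   # batch of 4 sequences, 10 tokens each
h0 = (torch.zeros(2, 4, 32), torch.zeros(2, 4, 32))  # (num_layers, batch, hidden_size)
out, h = toy_model(x, h0)
print(out.shape)  # torch.Size([4, 10, 100]) -- one vector of vocabulary logits per position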

Now comes the training part. It is the most time-consuming step: finding good hyperparameters and debugging the network is 90 percent of the work. You can see all the hyperparameters below. Saving the model between epochs is good practice and will also save you time.

from torch.utils.data import DataLoader
import torch.optim as optim

check_point = 19  # index of the last saved epoch; set to -1 to train from scratch

batch_size = 30
hidden_size = 2048
hidden_layer = 2
embedding_size = 1024
epoch = 20
learning_rate = 0.001
clip = 0.25  # max gradient norm for gradient clipping

if torch.cuda.is_available():
    # Set device to CUDA
    device = torch.device("cuda")
    print("CUDA is available! Using GPU for training.")
else:
    # Set device to CPU
    device = torch.device("cpu")
    print("CUDA is not available. Using CPU for training.")
dset = Text_dataset(train_data,text_len)
model = LSTM_Gen(vocab,hidden_size,hidden_layer,embedding_size,batch_size)
model = model.to(device)

optimizer = optim.Adam(model.parameters(), lr=learning_rate)
criterion = torch.nn.CrossEntropyLoss()
if check_point>=0:
  model.load_state_dict(torch.load(f'model_{check_point}.pth'))
model.train()
for e in range(check_point+1,epoch):
  train_loader = DataLoader(dset, batch_size=batch_size, shuffle=True)

  for i, data in enumerate(train_loader):
    optimizer.zero_grad()
    inp, label = data[0],data[1]
    inp = inp.to(device)
    label = label.to(device)
    hidden = (torch.zeros(hidden_layer, inp.shape[0], hidden_size).to(device),
              torch.zeros(hidden_layer, inp.shape[0], hidden_size).to(device))
    out,_ = model(inp,hidden)
    out = out.view(-1,vocab)
    label = label.view(-1)
    loss = criterion(out, label)
    loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), clip)  # clip gradients to avoid exploding gradients
    optimizer.step()

    if i%100==0:
      print(f'Epoch {e}, batch {i}: loss = {loss.item():.3f}')
      print(torch.argmax(out[10]).item(), label[10].item())  # predicted vs. true token index for one position
  torch.save(model.state_dict(), f'model_{e}.pth')
CUDA is available! Using GPU for training.
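The printed loss is the average cross-entropy per token, so it can also be read as perplexity, the standard language-modelling metric, by exponentiating it. A small helper (not part of the original script):

import math

def perplexity(cross_entropy):
    # Perplexity = exp(average per-token cross-entropy)
    return math.exp(cross_entropy)

print(perplexity(5.0))  # e.g. a loss of 5.0 corresponds to a perplexity of about 148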
model.eval()
# start = word_to_index['Germany']
start = 1500                              # index of the seed word
input = torch.tensor([[start]]).to(device)
print(most_common_words[start][0], end=' ')
hidden = (torch.zeros(hidden_layer, 1, hidden_size).to(device),
          torch.zeros(hidden_layer, 1, hidden_size).to(device))
with torch.no_grad():
  for i in range(100):
    out, hidden = model(input, hidden)
    out = torch.argmax(out[0][0]).item()  # greedy decoding: take the most probable next token
    input = torch.tensor([[out]]).to(device)
    if out != vocab-1:
      print(most_common_words[out][0], end=' ')
    else:
      print('<unk>', end=' ')
    if (i+1)%15 == 0:
      print()
Smith , the <unk> <unk> , and the <unk> <unk> , which was serialised in the 
United States on November 30 , 2003 . The episode was watched by 8 @.@ 
8 million viewers , and the second single in Europe . The album was the 
first single from the album , and it was released on the album . It 
was released on July 26 , 2012 , and the album peaked at number one 
on the Billboard 200 . It was certified double platinum by the Recording Industry Association 
of America ( RIAA ) , denoting shipments of ten 

It doesn’t look impressive, but the model at least learned how to form sentences and where to put ‘.’, ‘,’ etc., given our small dataset and limited computational power. With more training data and compute this approach would scale well, and we would probably struggle to tell model-written text from human-written text.
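Part of the repetitiveness also comes from greedy argmax decoding: the model always emits its single most probable word. A common alternative (shown here only as a sketch, not something used in the run above) is to sample from the softmax distribution with a temperature:

import torch
import torch.nn.functional as F

def sample_next(logits, temperature=0.8):
    # logits: 1-D tensor of vocabulary scores; lower temperature -> safer, less varied choices
    probs = F.softmax(logits / temperature, dim=-1)
    return torch.multinomial(probs, num_samples=1).item()

# In the generation loop, the greedy line
#   out = torch.argmax(out[0][0]).item()
# could be replaced with
#   out = sample_next(out[0][0])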