# Imports
import jax
import jax.numpy as jnp
import numpy as np
from jax import random, jit, grad
from flax import linen as nn
from nltk.tokenize import word_tokenize, sent_tokenize
from collections import Counter
import optax
import plotly.graph_objects as go
import plotly.io as pio
import umap
from figure_template import add_figure_templates
Overview
A deep dive on the Global Vectors for Word Representation (GloVe) algorithm with an implementation in JAX.
Introduction
In this short post, we’ll do a deep dive on the Global Vectors for Word Representation (GloVe) algorithm. Like Skip-Gram with Negative Sampling (SGNS), GloVe produces dense, static embeddings. These embeddings have a fixed dimension (commonly 300), and each token in the vocabulary is assigned a single embedding, irrespective of multiple word senses.
GloVe learns word relationships from global co-occurrence statistics. To begin, we construct a co-occurrence matrix over our corpus, which in this case is the same dataset we used in the Skip-Gram With Negative Sampling post.
The core idea is straightforward: the model minimizes the weighted squared error between the logarithm of observed co-occurrences and their predicted values. Let’s dive in!
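To make this concrete with made-up numbers: if a pair of words co-occurs 50 times in the corpus, the model nudges the dot product of their two vectors (plus their bias terms) toward \(\log 50 \approx 3.9\), and the squared error of that prediction is weighted by a function of the count so that rare pairs contribute less and very frequent pairs are capped.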
Preprocessing Utilities
We are going to use the same preprocessing functions (tokenization, tokens-to-index, etc…) that we used when we went over the Skip-Gram With Negative Sampling algorithm. These are shown again below for your reference.
def tokenize_and_filter(text: str, min_occurrence: int) -> list[str]:
"""
Tokenizes input text and filters out tokens that occur less than or equal to `min_occurrence` times.
Parameters
----------
text: str
Input text to be tokenized.
min_occurrence: int
Minimum number of times a token must appear in the text to be included.
Returns
-------
list[str]:
A list of tokens that occur more than `min_occurrence` times.
"""
# Convert text to lowercase and tokenize
tokens = word_tokenize(text.lower().strip())
# Count occurrences of each token
token_counts = Counter(tokens)
# Filter tokens based on minimum occurrence threshold
filtered_tokens = [token for token in tokens if token_counts[token] > min_occurrence]
return filtered_tokens
def build_word_index(text: str, min_occurrence: int = 0) -> tuple[dict[str, int], list[str]]:
"""
Builds a word-to-index mapping (vocabulary) from input text, including a special <unk> token.
Parameters
----------
text: str
Input text to extract vocabulary from.
min_occurrence: int
Minimum number of times a token must appear in the text to be included.
Returns
-------
tuple[dict[str, int], list[str]]:
- word_to_index: A dictionary mapping each word to a unique integer ID.
- vocabulary: A sorted list of unique vocabulary words including "<unk>".
"""
# Tokenize and get unique words
unique_tokens = set(tokenize_and_filter(text, min_occurrence=min_occurrence))
# Add special token for unknown words
vocabulary = sorted(unique_tokens) + ["<unk>"]
# Create word-to-index mapping
word_to_index = {word: idx for idx, word in enumerate(vocabulary)}
return word_to_index, vocabulary
def tokens_to_ids(text: str, word_to_index: dict[str, int], min_occurrence: int = 0) -> list[int]:
"""
Converts a text string into a list of integer token IDs using a provided word-to-index mapping.
Unknown tokens are mapped to the ID of the "<unk>" token.
Parameters
----------
text: str
Input text to convert.
word_to_index: dict[str, int]
A dictionary mapping words to their corresponding integer IDs.
min_occurrence: int
Minimum number of times a token must appear in the text to be included.
Returns
-------
list[int]:
A list of integer IDs representing the tokenized input text.
"""
tokens = tokenize_and_filter(text, min_occurrence=min_occurrence)
token_ids = [word_to_index.get(token, word_to_index["<unk>"]) for token in tokens]
return token_ids
Global Vectors for Word Representation (GloVe)
Conceptually, GloVe is simpler than Skip-Gram with Negative Sampling (SGNS) because it avoids negative sampling and works directly with co-occurrence statistics. However, in practice SGNS can be easier to scale to very large corpora, since GloVe requires building and storing a co-occurrence matrix in memory.
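As an aside, if memory does become a constraint, one workaround (not used in this post) is to accumulate counts in a dictionary and keep only the non-zero entries in a sparse matrix. A minimal sketch, assuming SciPy is available and using a hypothetical helper name:

```python
# Sketch only (not part of this post's pipeline): store co-occurrence counts
# sparsely instead of allocating a dense (V x V) array.
from collections import defaultdict
from scipy.sparse import coo_matrix

def build_sparse_cooc(token_ids: list[int], vocab_size: int, window_size: int = 5) -> coo_matrix:
    counts = defaultdict(float)
    for idx, i in enumerate(token_ids):
        start = max(0, idx - window_size)
        end = min(len(token_ids), idx + window_size + 1)
        for j in range(start, end):
            if j != idx:
                counts[(i, token_ids[j])] += 1.0 / abs(j - idx)  # inverse distance weighting
    rows, cols = zip(*counts.keys())
    return coo_matrix((list(counts.values()), (rows, cols)), shape=(vocab_size, vocab_size))
```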
That being said, let’s take a look at how we can generate our co-occurrence matrix.
Co-occurrence Matrices
In addition to creating the matrix, we need a utility that uses this matrix and returns a tuple of row indices, column indices, and co-occurrence counts. This is how we will pass the information to our loss function.
def build_cooc_matrix(corpus_tokens: list[str], vocabulary: list[str], word_to_index: dict[str, int], window_size: int = 5) -> np.ndarray:
"""
Build a co-occurrence matrix of a specified window size.
Parameters
----------
corpus_tokens: list[str]
tokenized text corpus
vocabulary: list[str]
sorted vocabulary
word_to_index: dict[str, int]
mapping from word to index in vocab
window_size: int
number of words to consider on each side
Returns
-------
np.ndarray
cooc of shape (vocabulary_size, vocabulary_size)
"""
vocab_size = len(vocabulary)
cooc_matrix = np.zeros((vocab_size, vocab_size), dtype=np.float32)
for idx, word in enumerate(corpus_tokens):
if word not in word_to_index:
continue # skip OOV words
i = word_to_index[word]
# Context window boundaries
start = max(0, idx - window_size)
end = min(len(corpus_tokens), idx + window_size + 1)
for j in range(start, end):
if j == idx:
continue # skip the word itself
context_word = corpus_tokens[j]
if context_word not in word_to_index:
continue
k = word_to_index[context_word]
distance = abs(j - idx)
cooc_matrix[i, k] += 1.0 / distance # inverse distance weighting
return cooc_matrix
def cooc_matrix_to_arrays(cooc_matrix: np.ndarray) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
"""
Convert a dense co-occurrence matrix to separate arrays of indices and counts.
Parameters
----------
cooc_matrix: np.ndarray
cooccurrence matrix
Returns
-------
i_indices: np.ndarray[int32]
j_indices: np.ndarray[int32]
counts: np.ndarray[float32]
"""
i_indices, j_indices = np.nonzero(cooc_matrix)
counts = cooc_matrix[i_indices, j_indices]
return (
i_indices.astype(np.int32),
j_indices.astype(np.int32),
counts.astype(np.float32),
)
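To make the output of these two helpers concrete, here is a quick illustration on a made-up mini corpus (the tokens and window size below are purely for demonstration):

```python
# Toy illustration of build_cooc_matrix and cooc_matrix_to_arrays.
toy_tokens = ["the", "people", "elect", "the", "government"]
toy_vocab = sorted(set(toy_tokens))
toy_w2i = {w: i for i, w in enumerate(toy_vocab)}

toy_cooc = build_cooc_matrix(toy_tokens, toy_vocab, toy_w2i, window_size=2)
toy_i, toy_j, toy_counts = cooc_matrix_to_arrays(toy_cooc)

# Each triple (i, j, X_ij) becomes one training example for the loss below.
for i, j, c in zip(toy_i, toy_j, toy_counts):
    print(f"{toy_vocab[i]:>10} -> {toy_vocab[j]:<10} {c:.2f}")
```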
Loss Function
The loss function is straightforward: we approximate the logarithm of co-occurrence counts. The model uses two embedding matrices (for target and context words) and a bias term for each. The predicted value is the dot product of the embeddings plus the biases, and we minimize the squared error between this prediction and the true log co-occurrence. Importantly, the error is weighted by a function of the co-occurrence count, which downweights very rare pairs and caps the influence of very frequent ones.
More formally, the loss function is: \[ J = \sum_{i,j} f(X_{ij}) \left( w_{i}^{T}\tilde{w}_{j} + b_{i} + \tilde{b}_{j} - \log X_{ij} \right)^{2} \]
Where \(f(\cdot)\) is the weighting function, \(w\) and \(\tilde{w}\) are the target and context embedding matrices, and \(b_{i}\), \(\tilde{b}_{j}\) are their bias terms.
def weighting_fn(x: np.ndarray, xmax: int = 100, alpha: float = 0.75):
return jnp.where(x < xmax, (x / xmax) ** alpha, 1.0)
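To see what this weighting does, we can evaluate it at a few representative counts (output values are approximate):

```python
# Rare pairs are down-weighted, and anything at or above xmax is capped at 1.0.
for x in [1.0, 10.0, 50.0, 100.0, 1000.0]:
    print(x, round(float(weighting_fn(jnp.array(x))), 3))
# approximately: 0.032, 0.178, 0.595, 1.0, 1.0
```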
def glove_loss(params: tuple[jnp.ndarray], cooc_i: jnp.ndarray, cooc_j: jnp.ndarray, X_ij: jnp.ndarray):
W, W_tilde, b, b_tilde = params
wi = W[cooc_i]
wj = W_tilde[cooc_j]
bi = b[cooc_i]
bj = b_tilde[cooc_j]
pred = jnp.sum(wi * wj, axis=1) + bi + bj
logX = jnp.log(jnp.maximum(X_ij, 1e-10))
fX = weighting_fn(X_ij)
loss = jnp.sum(fX * (pred - logX) ** 2)
return loss
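As a quick sanity check, the loss can be evaluated on tiny, randomly initialized parameters; the shapes and values below are arbitrary and only illustrate the expected inputs:

```python
# Smoke test with arbitrary shapes and values.
key = jax.random.PRNGKey(0)
k1, k2 = jax.random.split(key)
V, D = 8, 4  # tiny vocabulary and embedding dimension
toy_params = (
    jax.random.normal(k1, (V, D)) * 0.01,  # W
    jax.random.normal(k2, (V, D)) * 0.01,  # W_tilde
    jnp.zeros(V),                          # b
    jnp.zeros(V),                          # b_tilde
)
toy_i = jnp.array([0, 1, 2], dtype=jnp.int32)
toy_j = jnp.array([1, 2, 3], dtype=jnp.int32)
toy_X = jnp.array([2.0, 0.5, 1.0])
print(glove_loss(toy_params, toy_i, toy_j, toy_X))  # a single scalar loss value
```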
Learning Embeddings
Now, for the hands-on portion, we are going to use the same dataset that we compiled in the previous post. We are going to load the data and create our tokens, indices, and tokenized sentences exactly as we did previously. However, this time around we are going to pre-compute the co-occurrence matrix and structure its elements into row indices, column indices, and counts.
Load and Process Dataset
with open('./political_corpus.txt', 'r', encoding='utf-8') as file:
text = file.read()
raw_tokens = tokenize_and_filter(text, min_occurrence=10)
word_to_index, vocabulary = build_word_index(text=text, min_occurrence=10)
index_to_word = {v:k for k,v in word_to_index.items()}
cooc_matrix = build_cooc_matrix(raw_tokens, vocabulary, word_to_index, 10)
i_indices, j_indices, counts = cooc_matrix_to_arrays(cooc_matrix)
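Before defining the training loop, a quick optional check on the sizes we just produced:

```python
# Optional sanity check on the preprocessed corpus.
print(f"Vocabulary size: {len(vocabulary)}")
print(f"Non-zero co-occurrence pairs: {len(counts)}")
```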
Define Training Loop
Our training loop is very similar to what we had in the previous post; however, this time we wrap everything up in functions for a cleaner presentation.
def nearest_neighbors(word: str, vectors: jnp.ndarray, k:int = 3) -> list[str]:
"""
Finds k nearest neighbors to the given word using cosine similarity.
Parameters
----------
word: str
Target word as a string.
vectors: jnp.ndarray
Trained embeddings.
k: int
Number of nearest neighbors to return
Returns
-------
List[str]
Nearest neighbor tokens to word token
"""
index = tokens_to_ids(word, word_to_index)
query = vectors[index[0]]
# Normalize vectors and query
vectors_norm = vectors / jnp.linalg.norm(vectors, axis=1, keepdims=True)
query_norm = query / jnp.linalg.norm(query)
# Compute cosine similarities
similarities = jnp.dot(vectors_norm, query_norm)
# Get top k indices (excluding the word itself)
top_indices = jnp.argsort(similarities)[::-1][1:k+1]
# Map back to words
return [index_to_word[idx] for idx in np.array(top_indices)]
def log_step(epoch, params, total_loss, n_samples):
"""
Simple function to print out summaries after each epoch
"""
print(f"Epoch {epoch+1}, Avg Loss: {total_loss / n_samples:.4f}")
# Log embeddings!
embeddings = params[0] + params[1]
print('\nLearned embeddings:')
print(f'word: "tyranny" neighbors: {nearest_neighbors("tyranny", embeddings)}\n')
def train_glove(
i_indices: np.ndarray,
j_indices: np.ndarray,
counts: np.ndarray,
vocab_size: int,
dim: int = 300,
epochs: int = 50,
lr: float = 1e-3,
batch_size: int = 512
):
"""
"""
key = jax.random.PRNGKey(0)
# Split the key so the two embedding matrices get different initializations
key_w, key_wt = jax.random.split(key)
W = jax.random.normal(key_w, (vocab_size, dim)) * 0.01
W_tilde = jax.random.normal(key_wt, (vocab_size, dim)) * 0.01
b = jnp.zeros(vocab_size)
b_tilde = jnp.zeros(vocab_size)
params = (W, W_tilde, b, b_tilde)
opt = optax.adam(lr)
opt_state = opt.init(params)
n_samples = len(counts)
@jax.jit
def update(params: tuple[jnp.ndarray], opt_state, cooc_i: jnp.ndarray, cooc_j: jnp.ndarray, X_ij: jnp.ndarray):
loss, grads = jax.value_and_grad(glove_loss)(params, cooc_i, cooc_j, X_ij)
updates, opt_state = opt.update(grads, opt_state)
params = optax.apply_updates(params, updates)
return params, opt_state, loss
for epoch in range(epochs):
total_loss = 0.0
# shuffle each epoch
perm = np.random.permutation(n_samples)
for start in range(0, n_samples, batch_size):
end = min(start + batch_size, n_samples)
batch_idx = perm[start:end]
cooc_i = jnp.array(i_indices[batch_idx])
cooc_j = jnp.array(j_indices[batch_idx])
X_ij = jnp.array(counts[batch_idx])
params, opt_state, loss = update(params, opt_state, cooc_i, cooc_j, X_ij)
total_loss += float(loss) * len(batch_idx)
log_step(epoch, params, total_loss, n_samples)
return params
Train
Alright, let’s run through our training loop!!
trained_params = train_glove(
i_indices,
j_indices,
counts,
len(vocabulary),
)
embeddings = trained_params[0] + trained_params[1]
Epoch 1, Avg Loss: 30.8074
Learned embeddings:
word: "tyranny" neighbors: ['choice', 'authorities', 'defects']
Epoch 2, Avg Loss: 12.0077
Learned embeddings:
word: "tyranny" neighbors: ['defects', 'encroachments', 'organization']
Epoch 3, Avg Loss: 9.5171
Learned embeddings:
word: "tyranny" neighbors: ['appointing', 'judicial', 'physical']
Epoch 4, Avg Loss: 7.1958
Learned embeddings:
word: "tyranny" neighbors: ['causes', 'principal', 'judicial']
Epoch 5, Avg Loss: 5.5379
Learned embeddings:
word: "tyranny" neighbors: ['causes', 'arbitrary', 'principal']
Epoch 6, Avg Loss: 4.3948
Learned embeddings:
word: "tyranny" neighbors: ['arbitrary', 'causes', 'iii']
Epoch 7, Avg Loss: 3.5873
Learned embeddings:
word: "tyranny" neighbors: ['arbitrary', 'causes', 'anarchy']
Epoch 8, Avg Loss: 2.9879
Learned embeddings:
word: "tyranny" neighbors: ['arbitrary', 'anarchy', 'iii']
Epoch 9, Avg Loss: 2.5295
Learned embeddings:
word: "tyranny" neighbors: ['arbitrary', 'anarchy', 'causes']
Epoch 10, Avg Loss: 2.1881
Learned embeddings:
word: "tyranny" neighbors: ['arbitrary', 'anarchy', 'majority']
Epoch 11, Avg Loss: 1.9223
Learned embeddings:
word: "tyranny" neighbors: ['anarchy', 'arbitrary', 'majority']
Epoch 12, Avg Loss: 1.7164
Learned embeddings:
word: "tyranny" neighbors: ['anarchy', 'arbitrary', 'iii']
Epoch 13, Avg Loss: 1.5604
Learned embeddings:
word: "tyranny" neighbors: ['arbitrary', 'majority', 'anarchy']
Epoch 14, Avg Loss: 1.4382
Learned embeddings:
word: "tyranny" neighbors: ['majority', 'anarchy', 'administration']
Epoch 15, Avg Loss: 1.3376
Learned embeddings:
word: "tyranny" neighbors: ['majority', 'anarchy', 'administration']
Epoch 16, Avg Loss: 1.2611
Learned embeddings:
word: "tyranny" neighbors: ['majority', 'iii', 'administration']
Epoch 17, Avg Loss: 1.1893
Learned embeddings:
word: "tyranny" neighbors: ['administration', 'anarchy', 'iii']
Epoch 18, Avg Loss: 1.1444
Learned embeddings:
word: "tyranny" neighbors: ['majority', 'administration', 'iii']
Epoch 19, Avg Loss: 1.0996
Learned embeddings:
word: "tyranny" neighbors: ['majority', 'administration', 'anarchy']
Epoch 20, Avg Loss: 1.0615
Learned embeddings:
word: "tyranny" neighbors: ['majority', 'administration', 'iii']
Epoch 21, Avg Loss: 1.0228
Learned embeddings:
word: "tyranny" neighbors: ['administration', 'majority', 'arbitrary']
Epoch 22, Avg Loss: 1.0045
Learned embeddings:
word: "tyranny" neighbors: ['administration', 'majority', 'iii']
Epoch 23, Avg Loss: 0.9816
Learned embeddings:
word: "tyranny" neighbors: ['administration', 'majority', 'iii']
Epoch 24, Avg Loss: 0.9568
Learned embeddings:
word: "tyranny" neighbors: ['administration', 'majority', 'iii']
Epoch 25, Avg Loss: 0.9430
Learned embeddings:
word: "tyranny" neighbors: ['administration', 'majority', 'iii']
Epoch 26, Avg Loss: 0.9328
Learned embeddings:
word: "tyranny" neighbors: ['administration', 'majority', 'mass']
Epoch 27, Avg Loss: 0.9034
Learned embeddings:
word: "tyranny" neighbors: ['administration', 'majority', 'mass']
Epoch 28, Avg Loss: 0.9004
Learned embeddings:
word: "tyranny" neighbors: ['majority', 'administration', 'mass']
Epoch 29, Avg Loss: 0.8919
Learned embeddings:
word: "tyranny" neighbors: ['mass', 'administration', 'majority']
Epoch 30, Avg Loss: 0.8782
Learned embeddings:
word: "tyranny" neighbors: ['majority', 'administration', 'mass']
Epoch 31, Avg Loss: 0.8742
Learned embeddings:
word: "tyranny" neighbors: ['majority', 'administration', 'arbitrary']
Epoch 32, Avg Loss: 0.8549
Learned embeddings:
word: "tyranny" neighbors: ['administration', 'majority', 'mass']
Epoch 33, Avg Loss: 0.8482
Learned embeddings:
word: "tyranny" neighbors: ['administration', 'iii', 'majority']
Epoch 34, Avg Loss: 0.8530
Learned embeddings:
word: "tyranny" neighbors: ['administration', 'iii', 'mass']
Epoch 35, Avg Loss: 0.8392
Learned embeddings:
word: "tyranny" neighbors: ['administration', 'majority', 'action']
Epoch 36, Avg Loss: 0.8356
Learned embeddings:
word: "tyranny" neighbors: ['administration', 'mass', 'action']
Epoch 37, Avg Loss: 0.8258
Learned embeddings:
word: "tyranny" neighbors: ['mass', 'action', 'administration']
Epoch 38, Avg Loss: 0.8254
Learned embeddings:
word: "tyranny" neighbors: ['mass', 'administration', 'majority']
Epoch 39, Avg Loss: 0.8133
Learned embeddings:
word: "tyranny" neighbors: ['majority', 'action', 'administration']
Epoch 40, Avg Loss: 0.8121
Learned embeddings:
word: "tyranny" neighbors: ['action', 'mass', 'iii']
Epoch 41, Avg Loss: 0.8074
Learned embeddings:
word: "tyranny" neighbors: ['majority', 'mass', 'action']
Epoch 42, Avg Loss: 0.8145
Learned embeddings:
word: "tyranny" neighbors: ['action', 'administration', 'mass']
Epoch 43, Avg Loss: 0.7947
Learned embeddings:
word: "tyranny" neighbors: ['mass', 'action', 'majority']
Epoch 44, Avg Loss: 0.8011
Learned embeddings:
word: "tyranny" neighbors: ['mass', 'action', 'administration']
Epoch 45, Avg Loss: 0.7941
Learned embeddings:
word: "tyranny" neighbors: ['action', 'administration', 'iii']
Epoch 46, Avg Loss: 0.7883
Learned embeddings:
word: "tyranny" neighbors: ['mass', 'majority', 'action']
Epoch 47, Avg Loss: 0.7911
Learned embeddings:
word: "tyranny" neighbors: ['mass', 'action', 'arbitrary']
Epoch 48, Avg Loss: 0.7877
Learned embeddings:
word: "tyranny" neighbors: ['action', 'mass', 'administration']
Epoch 49, Avg Loss: 0.7827
Learned embeddings:
word: "tyranny" neighbors: ['mass', 'administration', 'action']
Epoch 50, Avg Loss: 0.7800
Learned embeddings:
word: "tyranny" neighbors: ['action', 'mass', 'majority']
Inspect learned embeddings
We use the same nearest-neighbors approach to inspect our learned embeddings as we did in the previous post. However, we added a plotting utility that plots the nearest neighbors after reducing the embedding dimensions with UMAP. It’s a little more exciting to have a visual, I suppose.
Looking at the results below, it is apparent that our embeddings have learned word relationships based on the co-occurrences within the corpus.
def reduce_umap(vectors, n_components=2, normalize=True, n_neighbors=10, min_dist=0.1, seed=13265):
"""simple dimensionality reduction for plotting"""
arr = np.array(vectors)
if normalize:
arr = arr / np.linalg.norm(arr, axis=1, keepdims=True)
reducer = umap.UMAP(
n_components=n_components,
n_neighbors=n_neighbors,
min_dist=min_dist,
random_state=seed,
metric='cosine',
n_jobs=1
)
return reducer.fit_transform(arr)
def plot_neighbors(word: str, vectors: jnp.ndarray, k: int = 5, theme="light"):
"""
Create a 2D Plotly scatterplot of nearest neighbors for a given word.
Parameters
----------
word: str
Target word.
vectors: jnp.ndarray
Trained embedding matrix.
k: int
Number of neighbors.
theme: str
The plot’s theme ("light" or "dark")
Returns
-------
plotly.graph_objects.Figure
"""
add_figure_templates()
if theme == "light":
template = pio.templates["mantine_light"]
else:
template = pio.templates["mantine_dark"]
# Get neighbors
neighbors = nearest_neighbors(word, vectors, k=k)
words = [word] + neighbors
# Get their embeddings
indices = np.array([tokens_to_ids(w, word_to_index)[0] for w in words])
emb = np.array(vectors[indices])
# Reduce to 2D
emb_2d = reduce_umap(emb, 2, True, k)
# plotly scatter plot
fig = go.Figure()
fig.add_trace(
go.Scatter(
x = emb_2d[:,0],
y = emb_2d[:,1],
mode="markers+text",
marker = dict(
color=["red"] + ["blue"]*k,
symbol=["star"] + ["circle"]*k,
size=14
),
text=words,
textposition="top center"
)
)
return fig.update_layout(
title = f"Nearest neighbors of '{word}'",
template=template
)
plot_neighbors("pennsylvania", embeddings, 10, "light")
plot_neighbors("pennsylvania", embeddings, 10, "dark")
plot_neighbors("nobles", embeddings, 10, "light")
plot_neighbors("nobles", embeddings, 10, "dark")
plot_neighbors("war", embeddings, 10, "light")
plot_neighbors("war", embeddings, 10, "dark")
Conclusion
In this short post we did a deep dive on the GloVe algorithm. Specifically, we focused on how training examples are generated from co-occurrence statistics, how the loss function is defined, and how to implement it in JAX.