# Imports
import jax
import jax.numpy as jnp
import numpy as np
from jax import random, jit, grad
from flax import linen as nn
from nltk.tokenize import word_tokenize, sent_tokenize
from collections import Counter
import optax
import plotly.graph_objects as go
import plotly.io as pio
import umap
from figure_template import add_figure_templates
Overview
A deep dive on the Global Vectors for Word Representation (GloVe) algorithm with an implementation in JAX.
Introduction
In this short post, we’ll do a deep dive on the Global Vectors for Word Representation (GloVe) algorithm. Like Skip-Gram with Negative Sampling (SGNS), GloVe produces dense, static embeddings. These embeddings have a fixed dimension (commonly 300), and each token in the vocabulary is assigned a single embedding, irrespective of multiple word senses.
GloVe learns word relationships from global co-occurrence statistics. To begin, we construct a co-occurrence matrix over our corpus, which in this case is the same dataset we used in the Skip-Gram With Negative Sampling post.
The core idea is straightforward: the model minimizes the weighted squared error between the logarithm of observed co-occurrences and their predicted values. Let’s dive in!
Preprocessing Utilities
We are going to use the same preprocessing functions (tokenization, tokens-to-index, etc…) that we used when we went over the Skip-Gram With Negative Sampling algorithm. These are shown again below for your reference.
def tokenize_and_filter(text: str, min_occurrence: int) -> list[str]:
    """
    Tokenizes input text and filters out tokens that occur less than or equal to `min_occurrence` times.
    Parameters
    ----------
    text: str
        Input text to be tokenized.
    min_occurrence: int
        Minimum number of times a token must appear in the text to be included.
    Returns
    -------
    list[str]:
        A list of tokens that occur more than `min_occurrence` times.
    """
    # Convert text to lowercase and tokenize
    tokens = word_tokenize(text.lower().strip())
    # Count occurrences of each token
    token_counts = Counter(tokens)
    # Filter tokens based on minimum occurrence threshold
    filtered_tokens = [token for token in tokens if token_counts[token] > min_occurrence]
    return filtered_tokens

def build_word_index(text: str, min_occurrence: int = 0) -> tuple[dict[str, int], list[str]]:
    """
    Builds a word-to-index mapping (vocabulary) from input text, including a special <unk> token.
    Parameters
    ----------
    text: str
        Input text to extract vocabulary from.
    min_occurrence: int
        Minimum number of times a token must appear in the text to be included.
    Returns
    -------
    tuple[dict[str, int], list[str]]:
        - word_to_index: A dictionary mapping each word to a unique integer ID.
        - vocabulary: A sorted list of unique vocabulary words including "<unk>".
    """
    # Tokenize and get unique words
    unique_tokens = set(tokenize_and_filter(text, min_occurrence=min_occurrence))
    # Add special token for unknown words
    vocabulary = sorted(unique_tokens) + ["<unk>"]
    # Create word-to-index mapping
    word_to_index = {word: idx for idx, word in enumerate(vocabulary)}
    return word_to_index, vocabulary

def tokens_to_ids(text: str, word_to_index: dict[str, int], min_occurrence: int = 0) -> list[int]:
    """
    Converts a text string into a list of integer token IDs using a provided word-to-index mapping.
    Unknown tokens are mapped to the ID of the "<unk>" token.
    Parameters
    ----------
    text: str
        Input text to convert.
    word_to_index: dict[str, int]
        A dictionary mapping words to their corresponding integer IDs.
    min_occurrence: int
        Minimum number of times a token must appear in the text to be included.
    Returns
    -------
    list[int]:
        A list of integer IDs representing the tokenized input text.
    """
    tokens = tokenize_and_filter(text, min_occurrence=min_occurrence)
    token_ids = [word_to_index.get(token, word_to_index["<unk>"]) for token in tokens]
    return token_ids
Global Vectors for Word Representation (GloVe)
Conceptually, GloVe is simpler than Skip-Gram with Negative Sampling (SGNS) because it avoids negative sampling and works directly with co-occurrence statistics. However, in practice SGNS can be easier to scale to very large corpora, since GloVe requires building and storing a co-occurrence matrix in memory.
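As a rough back-of-the-envelope sketch (not part of the original pipeline), here is why the dense matrix becomes the bottleneck: a float32 matrix over a vocabulary of \(V\) words needs \(V^2 \times 4\) bytes, so the cost grows quadratically with vocabulary size. The helper below is purely illustrative; large-scale GloVe implementations avoid this by keeping only the non-zero (i, j, count) entries.
# Illustrative only: memory needed for a dense float32 co-occurrence matrix
def dense_cooc_gigabytes(vocab_size: int) -> float:
    return vocab_size * vocab_size * 4 / 1e9

print(dense_cooc_gigabytes(10_000))   # ~0.4 GB -- manageable for a small corpus
print(dense_cooc_gigabytes(400_000))  # ~640 GB -- clearly not something to keep dense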
That being said, let’s take a look at how we can generate our co-occurrence matrix.
Co-occurrence Matrices
In addition to creating the matrix, we need a utility that uses this matrix and returns a tuple of row indices, column indices, and co-occurrence counts. This is how we will pass the information to our loss function.
def build_cooc_matrix(corpus_tokens: list[str], vocabulary: list[str], word_to_index: dict[str, int], window_size: int = 5) -> np.ndarray:
    """
    Build a co-occurrence matrix of a specified window size.
    Parameters
    ----------
    corpus_tokens: list[str]
        tokenized text corpus
    vocabulary: list[str]
        sorted vocabulary
    word_to_index: dict[str, int]
        mapping from word to index in vocab
    window_size: int
        number of words to consider on each side
    Returns
    -------
    np.ndarray
        cooc of shape (vocabulary_size, vocabulary_size)
    """
    vocab_size = len(vocabulary)
    cooc_matrix = np.zeros((vocab_size, vocab_size), dtype=np.float32)
    for idx, word in enumerate(corpus_tokens):
        if word not in word_to_index:
            continue  # skip OOV words
        i = word_to_index[word]
        # Context window boundaries
        start = max(0, idx - window_size)
        end = min(len(corpus_tokens), idx + window_size + 1)
        for j in range(start, end):
            if j == idx:
                continue  # skip the word itself
            context_word = corpus_tokens[j]
            if context_word not in word_to_index:
                continue
            k = word_to_index[context_word]
            distance = abs(j - idx)
            cooc_matrix[i, k] += 1.0 / distance  # inverse distance weighting
    return cooc_matrix
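As a quick sanity check (a toy example added here, not part of the original corpus), we can run the function on a six-word sentence with a window of two and verify both the inverse-distance weighting and the symmetry of the matrix:
# Toy example (illustrative only): a tiny "corpus" with window_size=2
toy_tokens = "the cat sat on the mat".split()
toy_vocab = sorted(set(toy_tokens))
toy_word_to_index = {w: i for i, w in enumerate(toy_vocab)}
toy_cooc = build_cooc_matrix(toy_tokens, toy_vocab, toy_word_to_index, window_size=2)
# "cat" and "sat" are adjacent (distance 1), so they contribute a full count
print(toy_cooc[toy_word_to_index["cat"], toy_word_to_index["sat"]])  # 1.0
# "cat" and "on" are two positions apart, so the contribution is weighted by 1/2
print(toy_cooc[toy_word_to_index["cat"], toy_word_to_index["on"]])   # 0.5
# The window extends in both directions, so the matrix is symmetric
print(np.allclose(toy_cooc, toy_cooc.T))  # True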
def cooc_matrix_to_arrays(cooc_matrix: np.ndarray) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
    """
    Convert a dense co-occurrence matrix to separate arrays of indices and counts.
    Parameters
    ----------
    cooc_matrix: np.ndarray
        cooccurrence matrix
    Returns
    -------
    i_indices: np.ndarray[int32]
    j_indices: np.ndarray[int32]
    counts: np.ndarray[float32]
    """
    i_indices, j_indices = np.nonzero(cooc_matrix)
    counts = cooc_matrix[i_indices, j_indices]
    return (
        i_indices.astype(np.int32),
        j_indices.astype(np.int32),
        counts.astype(np.float32),
    )
Loss Function
The loss function is straightforward: we approximate the logarithm of co-occurrence counts. The model uses two embedding matrices (for target and context words) and a bias term for each. The predicted value is the dot product of the embeddings plus the biases, and we minimize the squared error between this prediction and the true log co-occurrence. Importantly, the error is weighted by a function of the co-occurrence count, which downweights very rare pairs and caps the influence of very frequent ones.
More formally, the loss function is: \[ J = \sum_{i,j} f(X_{ij})\left(w_{i}^{T}\tilde{w}_{j} + b_{i} + \tilde{b}_{j} - \log X_{ij}\right)^{2} \]
Where \(f(\cdot)\) is the weighting function, \(w\) and \(\tilde{w}\) are the target and context embedding matrices, and \(b\) and \(\tilde{b}\) are the corresponding bias terms.
def weighting_fn(x: jnp.ndarray, xmax: int = 100, alpha: float = 0.75) -> jnp.ndarray:
    # Down-weight rare co-occurrences; cap the weight of frequent ones at 1.0
    return jnp.where(x < xmax, (x / xmax) ** alpha, 1.0)
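To get a feel for the weighting function (a quick illustrative check added here, with rounded numbers), note how a pair seen only once is heavily down-weighted, a moderately frequent pair gets an intermediate weight, and anything at or above xmax is capped at 1:
# Illustrative check of the weighting curve (outputs rounded)
print(weighting_fn(jnp.array([1.0, 20.0, 100.0, 500.0])))
# -> roughly [0.03, 0.30, 1.00, 1.00]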
def glove_loss(params: tuple[jnp.ndarray, ...], cooc_i: jnp.ndarray, cooc_j: jnp.ndarray, X_ij: jnp.ndarray):
    W, W_tilde, b, b_tilde = params
    # Look up target / context embeddings and biases for each (i, j) pair
    wi = W[cooc_i]
    wj = W_tilde[cooc_j]
    bi = b[cooc_i]
    bj = b_tilde[cooc_j]
    # Predicted log co-occurrence: dot product plus both biases
    pred = jnp.sum(wi * wj, axis=1) + bi + bj
    logX = jnp.log(jnp.maximum(X_ij, 1e-10))
    fX = weighting_fn(X_ij)
    # Weighted squared error summed over the batch
    loss = jnp.sum(fX * (pred - logX) ** 2)
    return loss
Learning Embeddings
Now, for the hands-on portion, we are going to use the same dataset that we compiled in the previous post. We are going to load the data and create our tokens, indices, and tokenized sentences exactly as we did previously. However, this time around we are going to pre-compute the co-occurrence matrix and structure its elements into row indices, column indices, and counts.
Load and Process Dataset
with open('./political_corpus.txt', 'r', encoding='utf-8') as file:
    text = file.read()
raw_tokens = tokenize_and_filter(text, min_occurrence=10)
word_to_index, vocabulary = build_word_index(text=text, min_occurrence=10)
index_to_word = {v: k for k, v in word_to_index.items()}
sentences = sent_tokenize(text.lower().strip())
cooc_matrix = build_cooc_matrix(raw_tokens, vocabulary, word_to_index, 10)
i_indices, j_indices, counts = cooc_matrix_to_arrays(cooc_matrix)
Define Training Loop
Our training loop is very similar to what we had in the previous post; however, this time we wrap everything up in functions for a cleaner presentation.
def nearest_neighbors(word: str, vectors: jnp.ndarray, k: int = 3) -> list[str]:
    """
    Finds k nearest neighbors to the given word using cosine similarity.
    Parameters
    ----------
    word: str
        Target word as a string.
    vectors: jnp.ndarray
        Trained embeddings.
    k: int
        Number of nearest neighbors to return.
    Returns
    -------
    list[str]
        Nearest neighbor tokens to word token.
    """
    index = tokens_to_ids(word, word_to_index)
    query = vectors[index[0]]
    # Normalize vectors and query
    vectors_norm = vectors / jnp.linalg.norm(vectors, axis=1, keepdims=True)
    query_norm = query / jnp.linalg.norm(query)
    # Compute cosine similarities
    similarities = jnp.dot(vectors_norm, query_norm)
    # Get top k indices (excluding the word itself)
    top_indices = jnp.argsort(similarities)[::-1][1:k+1]
    # Map back to words
    return [index_to_word[idx] for idx in np.array(top_indices)]

def log_step(epoch, params, total_loss, n_samples):
    """
    Simple function to print out summaries after each epoch.
    """
    print(f"Epoch {epoch+1}, Avg Loss: {total_loss / n_samples:.4f}")
    # Log embeddings!
    embeddings = params[0] + params[1]
    print('\nLearned embeddings:')
    print(f'word: "tyranny" neighbors: {nearest_neighbors("tyranny", embeddings)}\n')
def train_glove(
    i_indices: np.ndarray,
    j_indices: np.ndarray,
    counts: np.ndarray,
    vocab_size: int,
    dim: int = 300,
    epochs: int = 50,
    lr: float = 1e-3,
    batch_size: int = 512
):
    """
    Train GloVe embeddings with Adam over mini-batches of non-zero co-occurrence entries.
    """
    # Use separate PRNG keys so the target and context matrices start from different values
    key = jax.random.PRNGKey(0)
    key_w, key_w_tilde = jax.random.split(key)
    W = jax.random.normal(key_w, (vocab_size, dim)) * 0.01
    W_tilde = jax.random.normal(key_w_tilde, (vocab_size, dim)) * 0.01
    b = jnp.zeros(vocab_size)
    b_tilde = jnp.zeros(vocab_size)
    params = (W, W_tilde, b, b_tilde)
    opt = optax.adam(lr)
    opt_state = opt.init(params)
    n_samples = len(counts)

    @jax.jit
    def update(params: tuple[jnp.ndarray, ...], opt_state, cooc_i: jnp.ndarray, cooc_j: jnp.ndarray, X_ij: jnp.ndarray):
        loss, grads = jax.value_and_grad(glove_loss)(params, cooc_i, cooc_j, X_ij)
        updates, opt_state = opt.update(grads, opt_state)
        params = optax.apply_updates(params, updates)
        return params, opt_state, loss

    for epoch in range(epochs):
        total_loss = 0.0
        # shuffle each epoch
        perm = np.random.permutation(n_samples)
        for start in range(0, n_samples, batch_size):
            end = min(start + batch_size, n_samples)
            batch_idx = perm[start:end]
            cooc_i = jnp.array(i_indices[batch_idx])
            cooc_j = jnp.array(j_indices[batch_idx])
            X_ij = jnp.array(counts[batch_idx])
            params, opt_state, loss = update(params, opt_state, cooc_i, cooc_j, X_ij)
            total_loss += float(loss) * len(batch_idx)
        log_step(epoch, params, total_loss, n_samples)
    return params
Train
Alright, let’s run through our training loop! As in the original GloVe paper, we use the sum of the target and context embedding matrices, \(W + \tilde{W}\), as our final embeddings.
trained_params = train_glove(
    i_indices,
    j_indices,
    counts,
    len(vocabulary),
)
embeddings = trained_params[0] + trained_params[1]
Epoch 1, Avg Loss: 31.2775
Learned embeddings:
word: "tyranny" neighbors: ['justice', 'administration', 'view']
Epoch 2, Avg Loss: 12.0445
Learned embeddings:
word: "tyranny" neighbors: ['independence', 'election', 'justice']
Epoch 3, Avg Loss: 9.5596
Learned embeddings:
word: "tyranny" neighbors: ['justice', 'social', 'genius']
Epoch 4, Avg Loss: 7.2767
Learned embeddings:
word: "tyranny" neighbors: ['justice', 'maintenance', 'causes']
Epoch 5, Avg Loss: 5.5992
Learned embeddings:
word: "tyranny" neighbors: ['causes', 'justice', 'ii']
Epoch 6, Avg Loss: 4.4661
Learned embeddings:
word: "tyranny" neighbors: ['causes', 'iii', 'anarchy']
Epoch 7, Avg Loss: 3.6489
Learned embeddings:
word: "tyranny" neighbors: ['arbitrary', 'causes', 'anarchy']
Epoch 8, Avg Loss: 3.0393
Learned embeddings:
word: "tyranny" neighbors: ['arbitrary', 'anarchy', 'iii']
Epoch 9, Avg Loss: 2.5749
Learned embeddings:
word: "tyranny" neighbors: ['arbitrary', 'anarchy', 'ii']
Epoch 10, Avg Loss: 2.2202
Learned embeddings:
word: "tyranny" neighbors: ['arbitrary', 'anarchy', 'ii']
Epoch 11, Avg Loss: 1.9369
Learned embeddings:
word: "tyranny" neighbors: ['arbitrary', 'anarchy', 'majority']
Epoch 12, Avg Loss: 1.7404
Learned embeddings:
word: "tyranny" neighbors: ['arbitrary', 'majority', 'states—part']
Epoch 13, Avg Loss: 1.5700
Learned embeddings:
word: "tyranny" neighbors: ['arbitrary', 'majority', 'anarchy']
Epoch 14, Avg Loss: 1.4433
Learned embeddings:
word: "tyranny" neighbors: ['arbitrary', 'states—part', 'anarchy']
Epoch 15, Avg Loss: 1.3429
Learned embeddings:
word: "tyranny" neighbors: ['arbitrary', 'majority', 'anarchy']
Epoch 16, Avg Loss: 1.2633
Learned embeddings:
word: "tyranny" neighbors: ['arbitrary', 'waters', 'anarchy']
Epoch 17, Avg Loss: 1.1945
Learned embeddings:
word: "tyranny" neighbors: ['anarchy', 'waters', 'majority']
Epoch 18, Avg Loss: 1.1495
Learned embeddings:
word: "tyranny" neighbors: ['anarchy', 'majority', 'waters']
Epoch 19, Avg Loss: 1.0993
Learned embeddings:
word: "tyranny" neighbors: ['anarchy', 'waters', 'ii']
Epoch 20, Avg Loss: 1.0637
Learned embeddings:
word: "tyranny" neighbors: ['anarchy', 'waters', 'mass']
Epoch 21, Avg Loss: 1.0349
Learned embeddings:
word: "tyranny" neighbors: ['anarchy', 'waters', 'arbitrary']
Epoch 22, Avg Loss: 1.0054
Learned embeddings:
word: "tyranny" neighbors: ['arbitrary', 'majority', 'anarchy']
Epoch 23, Avg Loss: 0.9782
Learned embeddings:
word: "tyranny" neighbors: ['anarchy', 'waters', 'arbitrary']
Epoch 24, Avg Loss: 0.9602
Learned embeddings:
word: "tyranny" neighbors: ['anarchy', 'waters', 'majority']
Epoch 25, Avg Loss: 0.9458
Learned embeddings:
word: "tyranny" neighbors: ['mass', 'waters', 'anarchy']
Epoch 26, Avg Loss: 0.9283
Learned embeddings:
word: "tyranny" neighbors: ['arbitrary', 'mass', 'anarchy']
Epoch 27, Avg Loss: 0.9154
Learned embeddings:
word: "tyranny" neighbors: ['anarchy', 'waters', 'mass']
Epoch 28, Avg Loss: 0.8958
Learned embeddings:
word: "tyranny" neighbors: ['anarchy', 'arbitrary', 'mass']
Epoch 29, Avg Loss: 0.8922
Learned embeddings:
word: "tyranny" neighbors: ['anarchy', 'mass', 'arbitrary']
Epoch 30, Avg Loss: 0.8830
Learned embeddings:
word: "tyranny" neighbors: ['mass', 'anarchy', 'administration']
Epoch 31, Avg Loss: 0.8703
Learned embeddings:
word: "tyranny" neighbors: ['waters', 'anarchy', 'mass']
Epoch 32, Avg Loss: 0.8605
Learned embeddings:
word: "tyranny" neighbors: ['mass', 'anarchy', 'waters']
Epoch 33, Avg Loss: 0.8599
Learned embeddings:
word: "tyranny" neighbors: ['ii', 'mass', 'anarchy']
Epoch 34, Avg Loss: 0.8462
Learned embeddings:
word: "tyranny" neighbors: ['mass', 'waters', 'enforcing']
Epoch 35, Avg Loss: 0.8439
Learned embeddings:
word: "tyranny" neighbors: ['anarchy', 'arbitrary', 'mass']
Epoch 36, Avg Loss: 0.8337
Learned embeddings:
word: "tyranny" neighbors: ['anarchy', 'mass', 'waters']
Epoch 37, Avg Loss: 0.8279
Learned embeddings:
word: "tyranny" neighbors: ['mass', 'arbitrary', 'instrument']
Epoch 38, Avg Loss: 0.8299
Learned embeddings:
word: "tyranny" neighbors: ['mass', 'arbitrary', 'instrument']
Epoch 39, Avg Loss: 0.8237
Learned embeddings:
word: "tyranny" neighbors: ['mass', 'waters', 'instrument']
Epoch 40, Avg Loss: 0.8101
Learned embeddings:
word: "tyranny" neighbors: ['arbitrary', 'mass', 'waters']
Epoch 41, Avg Loss: 0.8165
Learned embeddings:
word: "tyranny" neighbors: ['mass', 'instrument', 'analogy']
Epoch 42, Avg Loss: 0.8076
Learned embeddings:
word: "tyranny" neighbors: ['arbitrary', 'mass', 'instrument']
Epoch 43, Avg Loss: 0.8014
Learned embeddings:
word: "tyranny" neighbors: ['mass', 'waters', 'enforcing']
Epoch 44, Avg Loss: 0.7988
Learned embeddings:
word: "tyranny" neighbors: ['conditions', 'mass', 'administration']
Epoch 45, Avg Loss: 0.7993
Learned embeddings:
word: "tyranny" neighbors: ['mass', 'analogy', 'instrument']
Epoch 46, Avg Loss: 0.7960
Learned embeddings:
word: "tyranny" neighbors: ['arbitrary', 'mass', 'instrument']
Epoch 47, Avg Loss: 0.7839
Learned embeddings:
word: "tyranny" neighbors: ['analogy', 'mass', 'instrument']
Epoch 48, Avg Loss: 0.7944
Learned embeddings:
word: "tyranny" neighbors: ['arbitrary', 'instrument', 'mass']
Epoch 49, Avg Loss: 0.7813
Learned embeddings:
word: "tyranny" neighbors: ['mass', 'anarchy', 'analogy']
Epoch 50, Avg Loss: 0.7801
Learned embeddings:
word: "tyranny" neighbors: ['instrument', 'mass', 'anarchy']
Inspect learned embeddings
We use the same nearest-neighbors approach to inspect our learned embeddings as we did in the previous post. However, this time we added a plotting utility that plots the nearest neighbors after reducing the embedding dimensionality with UMAP. It’s a little more exciting to have a visual, I suppose.
Looking at the results below, it is apparent that our embeddings have learned word relationships from the co-occurrences within the corpus.
def reduce_umap(vectors, n_components=2, normalize=True, n_neighbors=10, min_dist=0.1, seed=13265):
    """simple dimensionality reduction for plotting"""
    arr = np.array(vectors)
    if normalize:
        arr = arr / np.linalg.norm(arr, axis=1, keepdims=True)
    reducer = umap.UMAP(
        n_components=n_components,
        n_neighbors=n_neighbors,
        min_dist=min_dist,
        random_state=seed,
        metric='cosine',
        n_jobs=1
    )
    return reducer.fit_transform(arr)
def plot_neighbors(word: str, vectors: jnp.ndarray, k: int = 5, theme="light"):
    """
    Create a 2D Plotly scatterplot of nearest neighbors for a given word.
    Parameters
    ----------
    word: str
        Target word.
    vectors: jnp.ndarray
        Trained embedding matrix.
    k: int
        Number of neighbors.
    theme: str
        The plot's theme (light, dark)
    Returns
    -------
    plotly.graph_objects.Figure
    """
    add_figure_templates()
    if theme == "light":
        template = pio.templates["mantine_light"]
    else:
        template = pio.templates["mantine_dark"]
    # Get neighbors
    neighbors = nearest_neighbors(word, vectors, k=k)
    words = [word] + neighbors
    # Get their embeddings
    indices = np.array([tokens_to_ids(w, word_to_index)[0] for w in words])
    emb = np.array(vectors[indices])
    # Reduce to 2D
    emb_2d = reduce_umap(emb, 2, True, k)
    # plotly scatter plot
    fig = go.Figure()
    fig.add_trace(
        go.Scatter(
            x=emb_2d[:, 0],
            y=emb_2d[:, 1],
            mode="markers+text",
            marker=dict(
                color=["red"] + ["blue"] * k,
                symbol=["star"] + ["circle"] * k,
                size=14
            ),
            text=words,
            textposition="top center"
        )
    )
    return fig.update_layout(
        title=f"Nearest neighbors of '{word}'",
        template=template
    )
plot_neighbors("pennsylvania", embeddings, 10, "light")
plot_neighbors("pennsylvania", embeddings, 10, "dark")
plot_neighbors("nobles", embeddings, 10, "light")
plot_neighbors("nobles", embeddings, 10, "dark")
plot_neighbors("war", embeddings, 10, "light")
plot_neighbors("war", embeddings, 10, "dark")
Conclusion
In this short post we did a deep dive on the GloVe algorithm. Specifically, we focused on how training examples are generated from the co-occurrence matrix, how the loss function is defined, and how to implement it all in JAX.