How Machines Comprehend Language
Global Vectors for Word Representation

Natural Language Processing
Author

Jonathan Dekermanjian

Published

September 13, 2025

Overview

A deep dive on the Global Vectors for Word Representation (GloVe) algorithm with an implementation in JAX.

Introduction

In this short post, we’ll do a deep dive on the Global Vectors for Word Representation (GloVe) algorithm. Like Skip-Gram with Negative Sampling (SGNS), GloVe produces dense, static embeddings. These embeddings have a fixed dimension (commonly 300), and each token in the vocabulary is assigned a single embedding, irrespective of multiple word senses.

GloVe learns word relationships from global co-occurrence statistics. To begin, we construct a co-occurrence matrix over our corpus, which in this case is the same dataset we used in the Skip-Gram With Negative Sampling post.
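To make the idea of co-occurrence concrete, here is a toy illustration (the sentence and window size are made up for this example): every time two words appear within a small window of each other, we bump their joint count.

from collections import defaultdict

toy_tokens = "the people elect the president".split()
window = 2
pair_counts = defaultdict(float)

for idx, word in enumerate(toy_tokens):
    for j in range(max(0, idx - window), min(len(toy_tokens), idx + window + 1)):
        if j != idx:
            pair_counts[(word, toy_tokens[j])] += 1.0  # the full implementation below also weights by 1/distance

print(pair_counts[("the", "people")])  # 2.0 -- "people" falls within the window of "the" twice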

The core idea is straightforward: the model minimizes the weighted squared error between the logarithm of observed co-occurrences and their predicted values. Let’s dive in!

Preprocessing Utilities

We are going to use the same preprocessing functions (tokenization, tokens-to-index, etc…) that we used when we went over the Skip-Gram With Negative Sampling algorithm. These are shown again below for your reference.

# Imports
import jax
import jax.numpy as jnp
import numpy as np
from jax import random, jit, grad
from flax import linen as nn
from nltk.tokenize import word_tokenize, sent_tokenize
from collections import Counter
import optax
import plotly.graph_objects as go
import plotly.io as pio
import umap
from figure_template import add_figure_templates
def tokenize_and_filter(text: str, min_occurrence: int) -> list[str]:
    """
    Tokenizes input text and filters out tokens that occur less than or equal to `min_occurrence` times.

    Parameters
    ----------
    text: str
        Input text to be tokenized.
    min_occurrence: int
        Occurrence threshold: only tokens that appear more than this many times are kept.

    Returns
    -------
    list[str]: 
        A list of tokens that occur more than `min_occurrence` times.
    """
    # Convert text to lowercase and tokenize
    tokens = word_tokenize(text.lower().strip())

    # Count occurrences of each token
    token_counts = Counter(tokens)

    # Filter tokens based on minimum occurrence threshold
    filtered_tokens = [token for token in tokens if token_counts[token] > min_occurrence]

    return filtered_tokens
def build_word_index(text: str, min_occurence: int = 0) -> tuple[dict[str, int], list[str]]:
    """
    Builds a word-to-index mapping (vocabulary) from input text, including a special <unk> token.

    Parameters
    ----------
    text: str
        Input text to extract vocabulary from.
    min_occurrence: int
        Occurrence threshold: only tokens that appear more than this many times are kept.

    Returns
    -------
    Tuple[Dict[str, int], List[str]]: 
        - word_to_index: A dictionary mapping each word to a unique integer ID.
        - vocabulary: A sorted list of unique vocabulary words including "<unk>".
    """
    # Tokenize and get unique words
    unique_tokens = set(tokenize_and_filter(text, min_occurrence=min_occurence))

    # Add special token for unknown words
    vocabulary = sorted(unique_tokens) + ["<unk>"]

    # Create word-to-index mapping
    word_to_index = {word: idx for idx, word in enumerate(vocabulary)}

    return word_to_index, vocabulary
def tokens_to_ids(text: str, word_to_index: dict[str, int], min_occurence: int = 0) -> list[int]:
    """
    Converts a text string into a list of integer token IDs using a provided word-to-index mapping.
    Unknown tokens are mapped to the ID of the "<unk>" token.

    Parameters
    ----------
    text: str
        Input text to convert.
    word_to_index: dict[str, int]
        A dictionary mapping words to their corresponding integer IDs.
    min_occurrence: int
        Occurrence threshold: only tokens that appear more than this many times are kept.

    Returns
    -------
    list[int]:
        A list of integer IDs representing the tokenized input text.
    """
    tokens = tokenize_and_filter(text, min_occurrence=min_occurence)
    token_ids = [word_to_index.get(token, word_to_index["<unk>"]) for token in tokens]
    return token_ids
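Before moving on, here is a quick toy run of these utilities (the string is made up; with the default threshold every token is kept, and unseen words map to <unk>):

toy_text = "We the people. We the people of the union."
toy_word_to_index, toy_vocabulary = build_word_index(toy_text)
print(toy_vocabulary)                                         # sorted unique tokens plus "<unk>"
print(tokens_to_ids("the people decide", toy_word_to_index))  # "decide" is unseen, so it maps to "<unk>"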

Global Vectors for Word Representation (GloVe)

Conceptually, GloVe is simpler than Skip-Gram with Negative Sampling (SGNS) because it avoids negative sampling and works directly with co-occurrence statistics. However, in practice SGNS can be easier to scale to very large corpora, since GloVe requires building and storing a co-occurrence matrix in memory.
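As an aside, one way to ease that memory cost for larger vocabularies is to accumulate only the non-zero pairs in a dictionary keyed by (i, j) rather than allocating a dense vocabulary-by-vocabulary array. A minimal sketch of that idea (not what we use below, where the dense matrix is perfectly manageable):

from collections import defaultdict

def build_cooc_dict(corpus_tokens: list[str], word_to_index: dict[str, int], window_size: int = 5) -> dict[tuple[int, int], float]:
    """Sparse co-occurrence accumulator: only non-zero (i, j) cells are stored."""
    cooc = defaultdict(float)
    for idx, word in enumerate(corpus_tokens):
        if word not in word_to_index:
            continue
        i = word_to_index[word]
        for j in range(max(0, idx - window_size), min(len(corpus_tokens), idx + window_size + 1)):
            if j == idx or corpus_tokens[j] not in word_to_index:
                continue
            cooc[(i, word_to_index[corpus_tokens[j]])] += 1.0 / abs(j - idx)  # inverse distance weighting
    return cooc  # keys/values can then be unpacked into index and count arrays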

That being said, let’s take a look at how we can generate our co-occurrence matrix.

Co-occurrence Matrices

In addition to creating the matrix, we need a utility that uses this matrix and returns a tuple of row indices, column indices, and co-occurrence counts. This is how we will pass the information to our loss function.

def build_cooc_matrix(corpus_tokens: list[str], vocabulary: list[str], word_to_index: dict[str, int], window_size: int = 5) -> np.ndarray:
    """
    Build a co-occurrence matrix of a specified window size.
    
    Parameters
    ----------
    corpus_tokens: list[str] 
        tokenized text corpus
    vocabulary: list[str]
        sorted vocabulary
    word_to_index: dict[str, int]
        mapping from word to index in vocab
    window_size: int 
        number of words to consider on each side
    
    Returns
    -------
    np.ndarray 
        cooc of shape (vocabulary_size, vocabulary_size)
    """
    vocab_size = len(vocabulary)
    cooc_matrix = np.zeros((vocab_size, vocab_size), dtype=np.float32)
    
    for idx, word in enumerate(corpus_tokens):
        if word not in word_to_index:
            continue  # skip OOV words
        i = word_to_index[word]
        
        # Context window boundaries
        start = max(0, idx - window_size)
        end = min(len(corpus_tokens), idx + window_size + 1)
        
        for j in range(start, end):
            if j == idx:
                continue  # skip the word itself
            context_word = corpus_tokens[j]
            if context_word not in word_to_index:
                continue
            k = word_to_index[context_word]
            
            distance = abs(j - idx)
            cooc_matrix[i, k] += 1.0 / distance  # inverse distance weighting
    
    return cooc_matrix


def cooc_matrix_to_arrays(cooc_matrix: np.ndarray) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
    """
    Convert a dense co-occurrence matrix to separate arrays of indices and counts.
    
    Parameters
    ----------
    cooc_matrix: np.ndarray
        cooccurrence matrix
        
    Returns
    -------
    i_indices: np.ndarray[int32]
    j_indices: np.ndarray[int32]
    counts: np.ndarray[float32]
    """
    i_indices, j_indices = np.nonzero(cooc_matrix)
    counts = cooc_matrix[i_indices, j_indices]
    
    return (
        i_indices.astype(np.int32),
        j_indices.astype(np.int32),
        counts.astype(np.float32),
    )
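Before defining the loss, here is a quick sanity check of these two utilities on a tiny, made-up token list (the words and window size are arbitrary):

toy_tokens = ["liberty", "requires", "law", "and", "liberty"]
toy_vocab = sorted(set(toy_tokens)) + ["<unk>"]
toy_w2i = {w: i for i, w in enumerate(toy_vocab)}

toy_cooc = build_cooc_matrix(toy_tokens, toy_vocab, toy_w2i, window_size=2)
i_idx, j_idx, cnts = cooc_matrix_to_arrays(toy_cooc)

# Each triple (i, j, count) is one non-zero cell of the co-occurrence matrix
for i, j, c in zip(i_idx, j_idx, cnts):
    print(toy_vocab[i], toy_vocab[j], round(float(c), 2))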

Loss Function

The loss function is straightforward: we approximate the logarithm of co-occurrence counts. The model uses two embedding matrices (for target and context words) and a bias term for each. The predicted value is the dot product of the embeddings plus the biases, and we minimize the squared error between this prediction and the true log co-occurrence. Importantly, the error is weighted by a function of the co-occurrence count, which downweights very rare pairs and caps the influence of very frequent ones.

More formally, the loss function is: \[ J = \sum_{i,j} f(X_{ij})\left(w_{i}^{T}u_{j} + b_{i} + \tilde{b}_{j} - \log X_{ij}\right)^{2} \]

Where \(f(\cdot)\) is the weighting function, \(w\) and \(u\) are the target and context embedding matrices, and \(b_{i}\) and \(\tilde{b}_{j}\) are their corresponding bias terms.
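For reference, the weighting function implemented below is the one proposed in the original GloVe paper,

\[ f(x) = \begin{cases} (x / x_{\max})^{\alpha} & \text{if } x < x_{\max} \\ 1 & \text{otherwise} \end{cases} \]

with \(x_{\max} = 100\) and \(\alpha = 0.75\) as the default values.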

def weighting_fn(x: jnp.ndarray, xmax: int = 100, alpha: float = 0.75) -> jnp.ndarray:
    # Down-weight rare co-occurrences and cap the influence of very frequent ones
    return jnp.where(x < xmax, (x / xmax) ** alpha, 1.0)

def glove_loss(params: tuple[jnp.ndarray, ...], cooc_i: jnp.ndarray, cooc_j: jnp.ndarray, X_ij: jnp.ndarray):
    W, W_tilde, b, b_tilde = params
    # Look up the target/context embeddings and biases for this batch of (i, j) pairs
    wi = W[cooc_i]
    wj = W_tilde[cooc_j]
    bi = b[cooc_i]
    bj = b_tilde[cooc_j]

    # Prediction: dot product of the embeddings plus both biases
    pred = jnp.sum(wi * wj, axis=1) + bi + bj
    logX = jnp.log(jnp.maximum(X_ij, 1e-10))  # guard against log(0)
    fX = weighting_fn(X_ij)
    loss = jnp.sum(fX * (pred - logX) ** 2)  # weighted squared error
    return loss
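As a quick sanity check, we can evaluate the loss on a couple of made-up co-occurrence entries (the vocabulary size of 3 and embedding dimension of 2 here are arbitrary choices for illustration):

# Toy parameters: 3-word vocabulary, 2-dimensional embeddings (arbitrary)
toy_params = (
    jax.random.normal(jax.random.PRNGKey(0), (3, 2)) * 0.01,  # W
    jax.random.normal(jax.random.PRNGKey(1), (3, 2)) * 0.01,  # W_tilde
    jnp.zeros(3),                                              # b
    jnp.zeros(3),                                              # b_tilde
)
toy_i = jnp.array([0, 1])      # row indices
toy_j = jnp.array([1, 2])      # column indices
toy_X = jnp.array([4.0, 1.5])  # co-occurrence counts
print(glove_loss(toy_params, toy_i, toy_j, toy_X))  # a single scalar loss for the batch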

Learning Embeddings

Now, for the hands-on portion, we are going to use the same dataset that we compiled in the previous post. We load the data and build our tokens, word-to-index mapping, and tokenized sentences exactly as we did previously. However, this time around we also pre-compute the co-occurrence matrix and unpack its non-zero elements into row indices, column indices, and counts.

Load and Process Dataset

with open('./political_corpus.txt', 'r', encoding='utf-8') as file:
    text = file.read()
raw_tokens = tokenize_and_filter(text, min_occurrence=10)
word_to_index, vocabulary = build_word_index(text=text, min_occurence=10)
index_to_word = {v:k for k,v in word_to_index.items()}
sentences = sent_tokenize(text.lower().strip())
cooc_matrix = build_cooc_matrix(raw_tokens, vocabulary, word_to_index, 10)
i_indices, j_indices, counts = cooc_matrix_to_arrays(cooc_matrix)

Define Training Loop

Our training loop is very similar to the one in the previous post; however, this time we wrap everything up in functions for a cleaner presentation.

def nearest_neighbors(word: str, vectors: jnp.ndarray, k:int = 3) -> list[str]:
    """
    Finds k nearest neighbors to the given word using cosine similarity.

    Parameters
    ----------
    word: str
        Target word as a string.
    vectors: jnp.ndarray
        Trained embeddings.
    k: int
        Number of nearest neighbors to return

    Returns
    -------
    List[str]
        Nearest neighbor tokens to word token
    """
    index = tokens_to_ids(word, word_to_index)
    query = vectors[index[0]]

    # Normalize vectors and query
    vectors_norm = vectors / jnp.linalg.norm(vectors, axis=1, keepdims=True)
    query_norm = query / jnp.linalg.norm(query)

    # Compute cosine similarities
    similarities = jnp.dot(vectors_norm, query_norm)

    # Get top k indices (excluding the word itself)
    top_indices = jnp.argsort(similarities)[::-1][1:k+1]

    # Map back to words
    return [index_to_word[idx] for idx in np.array(top_indices)]
def log_step(epoch, params, total_loss, n_samples):
    """
    Simple function to print out summaries after each epoch
    """
    print(f"Epoch {epoch+1}, Avg Loss: {total_loss / n_samples:.4f}")
    # Log embeddings!
    embeddings = params[0] + params[1]
    print('\nLearned embeddings:')
    print(f'word: "tyranny" neighbors: {nearest_neighbors("tyranny", embeddings)}\n')

def train_glove(
    i_indices: np.ndarray, 
    j_indices: np.ndarray, 
    counts: np.ndarray, 
    vocab_size: int, 
    dim: int = 300, 
    epochs: int = 50, 
    lr: float = 1e-3, 
    batch_size: int = 512
):
    """
    
    """
    key = jax.random.PRNGKey(0)
    # Split the key so that W and W_tilde are not initialized identically
    key_w, key_w_tilde = jax.random.split(key)
    W = jax.random.normal(key_w, (vocab_size, dim)) * 0.01
    W_tilde = jax.random.normal(key_w_tilde, (vocab_size, dim)) * 0.01
    b = jnp.zeros(vocab_size)
    b_tilde = jnp.zeros(vocab_size)
    params = (W, W_tilde, b, b_tilde)

    opt = optax.adam(lr)
    opt_state = opt.init(params)

    n_samples = len(counts)

    @jax.jit
    def update(params: tuple[jnp.ndarray, ...], opt_state, cooc_i: jnp.ndarray, cooc_j: jnp.ndarray, X_ij: jnp.ndarray):
        loss, grads = jax.value_and_grad(glove_loss)(params, cooc_i, cooc_j, X_ij)
        updates, opt_state = opt.update(grads, opt_state)
        params = optax.apply_updates(params, updates)
        return params, opt_state, loss

    for epoch in range(epochs):
        total_loss = 0.0
        # shuffle each epoch
        perm = np.random.permutation(n_samples)
        for start in range(0, n_samples, batch_size):
            end = min(start + batch_size, n_samples)
            batch_idx = perm[start:end]
            cooc_i = jnp.array(i_indices[batch_idx])
            cooc_j = jnp.array(j_indices[batch_idx])
            X_ij = jnp.array(counts[batch_idx])
            params, opt_state, loss = update(params, opt_state, cooc_i, cooc_j, X_ij)
            total_loss += float(loss) * len(batch_idx)

        log_step(epoch, params, total_loss, n_samples)

    return params

Train

Alright, let’s run through our training loop!!

trained_params = train_glove(
    i_indices,
    j_indices,
    counts,
    len(vocabulary),
)

# Following the GloVe paper, the final word vectors are the sum of the target and context embeddings
embeddings = trained_params[0] + trained_params[1]
Epoch 1, Avg Loss: 31.2775

Learned embeddings:
word: "tyranny" neighbors: ['justice', 'administration', 'view']

Epoch 2, Avg Loss: 12.0445

Learned embeddings:
word: "tyranny" neighbors: ['independence', 'election', 'justice']

Epoch 3, Avg Loss: 9.5596

Learned embeddings:
word: "tyranny" neighbors: ['justice', 'social', 'genius']

Epoch 4, Avg Loss: 7.2767

Learned embeddings:
word: "tyranny" neighbors: ['justice', 'maintenance', 'causes']

Epoch 5, Avg Loss: 5.5992

Learned embeddings:
word: "tyranny" neighbors: ['causes', 'justice', 'ii']

Epoch 6, Avg Loss: 4.4661

Learned embeddings:
word: "tyranny" neighbors: ['causes', 'iii', 'anarchy']

Epoch 7, Avg Loss: 3.6489

Learned embeddings:
word: "tyranny" neighbors: ['arbitrary', 'causes', 'anarchy']

Epoch 8, Avg Loss: 3.0393

Learned embeddings:
word: "tyranny" neighbors: ['arbitrary', 'anarchy', 'iii']

Epoch 9, Avg Loss: 2.5749

Learned embeddings:
word: "tyranny" neighbors: ['arbitrary', 'anarchy', 'ii']

Epoch 10, Avg Loss: 2.2202

Learned embeddings:
word: "tyranny" neighbors: ['arbitrary', 'anarchy', 'ii']

Epoch 11, Avg Loss: 1.9369

Learned embeddings:
word: "tyranny" neighbors: ['arbitrary', 'anarchy', 'majority']

Epoch 12, Avg Loss: 1.7404

Learned embeddings:
word: "tyranny" neighbors: ['arbitrary', 'majority', 'states—part']

Epoch 13, Avg Loss: 1.5700

Learned embeddings:
word: "tyranny" neighbors: ['arbitrary', 'majority', 'anarchy']

Epoch 14, Avg Loss: 1.4433

Learned embeddings:
word: "tyranny" neighbors: ['arbitrary', 'states—part', 'anarchy']

Epoch 15, Avg Loss: 1.3429

Learned embeddings:
word: "tyranny" neighbors: ['arbitrary', 'majority', 'anarchy']

Epoch 16, Avg Loss: 1.2633

Learned embeddings:
word: "tyranny" neighbors: ['arbitrary', 'waters', 'anarchy']

Epoch 17, Avg Loss: 1.1945

Learned embeddings:
word: "tyranny" neighbors: ['anarchy', 'waters', 'majority']

Epoch 18, Avg Loss: 1.1495

Learned embeddings:
word: "tyranny" neighbors: ['anarchy', 'majority', 'waters']

Epoch 19, Avg Loss: 1.0993

Learned embeddings:
word: "tyranny" neighbors: ['anarchy', 'waters', 'ii']

Epoch 20, Avg Loss: 1.0637

Learned embeddings:
word: "tyranny" neighbors: ['anarchy', 'waters', 'mass']

Epoch 21, Avg Loss: 1.0349

Learned embeddings:
word: "tyranny" neighbors: ['anarchy', 'waters', 'arbitrary']

Epoch 22, Avg Loss: 1.0054

Learned embeddings:
word: "tyranny" neighbors: ['arbitrary', 'majority', 'anarchy']

Epoch 23, Avg Loss: 0.9782

Learned embeddings:
word: "tyranny" neighbors: ['anarchy', 'waters', 'arbitrary']

Epoch 24, Avg Loss: 0.9602

Learned embeddings:
word: "tyranny" neighbors: ['anarchy', 'waters', 'majority']

Epoch 25, Avg Loss: 0.9458

Learned embeddings:
word: "tyranny" neighbors: ['mass', 'waters', 'anarchy']

Epoch 26, Avg Loss: 0.9283

Learned embeddings:
word: "tyranny" neighbors: ['arbitrary', 'mass', 'anarchy']

Epoch 27, Avg Loss: 0.9154

Learned embeddings:
word: "tyranny" neighbors: ['anarchy', 'waters', 'mass']

Epoch 28, Avg Loss: 0.8958

Learned embeddings:
word: "tyranny" neighbors: ['anarchy', 'arbitrary', 'mass']

Epoch 29, Avg Loss: 0.8922

Learned embeddings:
word: "tyranny" neighbors: ['anarchy', 'mass', 'arbitrary']

Epoch 30, Avg Loss: 0.8830

Learned embeddings:
word: "tyranny" neighbors: ['mass', 'anarchy', 'administration']

Epoch 31, Avg Loss: 0.8703

Learned embeddings:
word: "tyranny" neighbors: ['waters', 'anarchy', 'mass']

Epoch 32, Avg Loss: 0.8605

Learned embeddings:
word: "tyranny" neighbors: ['mass', 'anarchy', 'waters']

Epoch 33, Avg Loss: 0.8599

Learned embeddings:
word: "tyranny" neighbors: ['ii', 'mass', 'anarchy']

Epoch 34, Avg Loss: 0.8462

Learned embeddings:
word: "tyranny" neighbors: ['mass', 'waters', 'enforcing']

Epoch 35, Avg Loss: 0.8439

Learned embeddings:
word: "tyranny" neighbors: ['anarchy', 'arbitrary', 'mass']

Epoch 36, Avg Loss: 0.8337

Learned embeddings:
word: "tyranny" neighbors: ['anarchy', 'mass', 'waters']

Epoch 37, Avg Loss: 0.8279

Learned embeddings:
word: "tyranny" neighbors: ['mass', 'arbitrary', 'instrument']

Epoch 38, Avg Loss: 0.8299

Learned embeddings:
word: "tyranny" neighbors: ['mass', 'arbitrary', 'instrument']

Epoch 39, Avg Loss: 0.8237

Learned embeddings:
word: "tyranny" neighbors: ['mass', 'waters', 'instrument']

Epoch 40, Avg Loss: 0.8101

Learned embeddings:
word: "tyranny" neighbors: ['arbitrary', 'mass', 'waters']

Epoch 41, Avg Loss: 0.8165

Learned embeddings:
word: "tyranny" neighbors: ['mass', 'instrument', 'analogy']

Epoch 42, Avg Loss: 0.8076

Learned embeddings:
word: "tyranny" neighbors: ['arbitrary', 'mass', 'instrument']

Epoch 43, Avg Loss: 0.8014

Learned embeddings:
word: "tyranny" neighbors: ['mass', 'waters', 'enforcing']

Epoch 44, Avg Loss: 0.7988

Learned embeddings:
word: "tyranny" neighbors: ['conditions', 'mass', 'administration']

Epoch 45, Avg Loss: 0.7993

Learned embeddings:
word: "tyranny" neighbors: ['mass', 'analogy', 'instrument']

Epoch 46, Avg Loss: 0.7960

Learned embeddings:
word: "tyranny" neighbors: ['arbitrary', 'mass', 'instrument']

Epoch 47, Avg Loss: 0.7839

Learned embeddings:
word: "tyranny" neighbors: ['analogy', 'mass', 'instrument']

Epoch 48, Avg Loss: 0.7944

Learned embeddings:
word: "tyranny" neighbors: ['arbitrary', 'instrument', 'mass']

Epoch 49, Avg Loss: 0.7813

Learned embeddings:
word: "tyranny" neighbors: ['mass', 'anarchy', 'analogy']

Epoch 50, Avg Loss: 0.7801

Learned embeddings:
word: "tyranny" neighbors: ['instrument', 'mass', 'anarchy']

Inspect learned embeddings

We use the same nearest neighbors approach to inspect our learned embeddings as we did in the previous post. However, this time we’ve added a plotting utility that visualizes the nearest neighbors after reducing the embedding vectors to two dimensions with UMAP. It’s a little more exciting to have a visual, I suppose.

Looking at the results below, it is apparent that our embeddings have learned word relationships from the co-occurrences within the corpus.

def reduce_umap(vectors, n_components=2, normalize=True, n_neighbors=10, min_dist=0.1, seed=13265):
    """simple dimensionality reduction for plotting"""
    arr = np.array(vectors)
    if normalize:
        arr = arr / np.linalg.norm(arr, axis=1, keepdims=True)
    reducer = umap.UMAP(
        n_components=n_components,
        n_neighbors=n_neighbors,
        min_dist=min_dist,
        random_state=seed,
        metric='cosine', 
        n_jobs=1
    )
    return reducer.fit_transform(arr)

def plot_neighbors(word: str, vectors: jnp.ndarray, k: int = 5, theme="light"):
    """
    Create a 2D Plotly scatterplot of nearest neighbors for a given word.
    
    Parameters
    ----------
    word: str
        Target word.
    vectors: jnp.ndarray
        Trained embedding matrix.
    k: int
        Number of neighbors.
    theme: str
        The plot's theme ("light" or "dark")
        
    Returns
    -------
    plotly.graph_objects.Figure
    """
    add_figure_templates()
    if theme == "light":
        template = pio.templates["mantine_light"]
    else:
        template = pio.templates["mantine_dark"]

    # Get neighbors
    neighbors = nearest_neighbors(word, vectors, k=k)
    words = [word] + neighbors
    
    # Get their embeddings
    indices = np.array([tokens_to_ids(w, word_to_index)[0] for w in words])
    emb = np.array(vectors[indices])
    
    # Reduce to 2D
    emb_2d = reduce_umap(emb, 2, True, k)
    
    # plotly scatter plot
    fig = go.Figure()
    
    fig.add_trace(
        go.Scatter(
            x = emb_2d[:,0],
            y = emb_2d[:,1],
            mode="markers+text",
            marker = dict(
                color=["red"] + ["blue"]*k,
                symbol=["star"] + ["circle"]*k,
                size=14
            ),
            text=words,
            textposition="top center"
        )
    )
    return fig.update_layout(
        title = f"Nearest neighbors of '{word}'",
        template=template
    )
plot_neighbors("pennsylvania", embeddings, 10, "light")
plot_neighbors("pennsylvania", embeddings, 10, "dark")
plot_neighbors("nobles", embeddings, 10, "light")
plot_neighbors("nobles", embeddings, 10, "dark")
plot_neighbors("war", embeddings, 10, "light")
plot_neighbors("war", embeddings, 10, "dark")

Conclusion

In this short post we did a deep dive on the GloVe algorithm. Specifically, we focused on how the training examples (co-occurrence counts) are generated, how the loss function is defined, and how to implement it all in JAX.