How Machines Comprehend Language
Recurrent Neural Networks - Embeddings from Language Models (ELMo)

Natural Language Processing
Author

Jonathan Dekermanjian

Published

January 31, 2026

Overview

We go over contextual embeddings and recurrent neural networks, describing their inner workings and walking through a from-scratch implementation of Embeddings from Language Models (ELMo) using the JAX ecosystem.

Introduction

In previous posts we learned what embeddings are, why they are important, and how they can be leveraged in practice. We got our hands dirty by implementing two of the earliest foundational methods: GloVe and Skip-gram with negative sampling (SGNS).

If you need a refresher or haven’t seen them yet, you can check out my posts on GloVe and SGNS.

In this post we will briefly review the differences between static and contextual embeddings and discuss some of the methodologies used to create contextual embeddings. This leads us to the architecture of recurrent neural networks and how they are leveraged to produce contextual embeddings.

Finally, we will delve into the architecture of Embeddings from Language Models (ELMo), and implement it from scratch.

From Static to Contextual Embeddings

As a reminder, a strength of contextual embeddings is the ability to discriminate between polysemous words. For example, the word “ship” can refer to the action verb of shipping an item or to the noun of a water vessel. A static embedding would represent both meanings of the word “ship” with the same vector. In contrast, contextual embeddings are able to disambiguate the two meanings of the same word.

Since our vectors now need to be aware of context, the order in which words occur must influence the resulting vectors. Importantly, our vectors must be context-aware not only during training but also during inference. Therefore, we must build embeddings dynamically (on the fly) at inference time, computing a fresh vector for every word in every sentence. This contrasts with earlier methods like SGNS or GloVe:

  • SGNS trains a vector for each word by predicting its local context, but at inference it simply looks up the pre‑computed static vector—there’s no per‑sentence adaptation.

  • GloVe builds a word‑co‑occurrence matrix and factorizes it, again yielding one static vector per word that is reused for every sentence.

One way we include information from sequences is by using recurrent computation. This brings us to the topic of recurrent neural networks.

Recurrent Neural Networks

A recurrent neural network (RNN) is a neural architecture designed for sequential data, in which information from previous time steps is carried forward through a recurrent hidden state. At each time step, the hidden state is updated based on the current input and the previous hidden state, and the output is computed from this hidden representation.

We have two equations that define a simple RNN, also known as an Elman RNN.

\[ h_{t} = \phi(W_{h}h_{t-1} + W_{x}x_{t} + b) \]

\[ y_{t} = \psi(W_{y}h_{t} + b_{y}) \]

Where \(x_{t}\), \(h_{t}\), and \(y_{t}\) are the input, hidden state, and output at time \(t\), respectively, and \(\phi\) and \(\psi\) are activation functions. The weight matrices \(W\) are linear operators that project variables from one space into another. For example, \(W_{x}\) projects \(x_{t}\) into the hidden state space. Notice that these weight matrices are not indexed by \(t\), signifying that they are shared across time, which allows RNNs to generalize to varying sequence lengths.
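To make the two equations concrete, here is a minimal JAX sketch of an Elman RNN forward pass. It assumes \(\phi = \tanh\) and \(\psi\) = identity; the function names, dimensions, and random initialization are illustrative choices, not part of any library. `jax.lax.scan` applies the same `step` function, and therefore the same weights, at every time step.

```python
import jax
import jax.numpy as jnp

def elman_rnn(params, xs, h0):
    """Run an Elman RNN over xs of shape [T, input_dim]; returns (h_T, ys)."""
    W_h, W_x, b, W_y, b_y = params

    def step(h_prev, x_t):
        # h_t = tanh(W_h h_{t-1} + W_x x_t + b)   (phi = tanh in this sketch)
        h_t = jnp.tanh(W_h @ h_prev + W_x @ x_t + b)
        # y_t = W_y h_t + b_y                     (psi = identity in this sketch)
        y_t = W_y @ h_t + b_y
        return h_t, y_t

    # scan carries h forward and stacks every y_t into ys: [T, output_dim]
    h_T, ys = jax.lax.scan(step, h0, xs)
    return h_T, ys

def init_params(key, input_dim, hidden_dim, output_dim):
    """Small random weights (an arbitrary choice for illustration)."""
    k1, k2, k3 = jax.random.split(key, 3)
    return (
        0.1 * jax.random.normal(k1, (hidden_dim, hidden_dim)),  # W_h
        0.1 * jax.random.normal(k2, (hidden_dim, input_dim)),   # W_x
        jnp.zeros(hidden_dim),                                  # b
        0.1 * jax.random.normal(k3, (output_dim, hidden_dim)),  # W_y
        jnp.zeros(output_dim),                                  # b_y
    )

params = init_params(jax.random.PRNGKey(0), input_dim=3, hidden_dim=5, output_dim=2)
h0 = jnp.zeros(5)
# The same parameters handle sequences of different lengths.
h_short, ys_short = elman_rnn(params, jnp.ones((4, 3)), h0)
h_long, ys_long = elman_rnn(params, jnp.ones((10, 3)), h0)
print(ys_short.shape, ys_long.shape)  # (4, 2) (10, 2)
```

Because the weights are not indexed by time, the same `params` run unchanged on a 4-step and a 10-step sequence.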

flowchart BT
  X[$$x_t$$] -->|input| H[$$h_t$$]
  H -->|output| Y[$$y_t$$]
  H -->|recurrent state| H

  style H fill:#eef,stroke:#333,stroke-width:1.5px

Figure 1: Elman (Simple) Recurrent Neural Network

The Elman RNN maintains only one hidden state, so at each time step the same computation must compress all past information into a single \(h_{t}\). This is challenging for long-range dependencies. Additionally, Elman RNNs are prone to vanishing/exploding gradients because, during backpropagation through time, gradients must repeatedly pass through \(W_{h}\) and the nonlinearity \(\phi\). In practice, these limitations are reason enough to opt for a modified RNN architecture.
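A quick numerical illustration of the vanishing gradient problem, under simplifying assumptions: inputs are dropped, \(\phi = \tanh\), and \(W_{h}\) is set to \(0.5 I\) (all hypothetical values). Each step multiplies the Jacobian \(\partial h_{T} / \partial h_{0}\) by \(\operatorname{diag}(1 - h_{t}^{2})\,W_{h}\), so in this setup its norm is at most \(0.5^{T}\).

```python
import jax
import jax.numpy as jnp

def unroll(h0, W_h, steps):
    """Apply h_t = tanh(W_h h_{t-1}) repeatedly (inputs dropped for clarity)."""
    h = h0
    for _ in range(steps):
        h = jnp.tanh(W_h @ h)
    return h

W_h = 0.5 * jnp.eye(4)  # assumed recurrent weights with spectral norm 0.5
h0 = jnp.array([0.3, -0.2, 0.1, 0.4])

norms = {}
for steps in (1, 5, 20):
    # Jacobian d h_T / d h_0, i.e. what backpropagation through time computes
    J = jax.jacobian(lambda h: unroll(h, W_h, steps))(h0)
    # Largest singular value measures how much gradient signal survives
    norms[steps] = float(jnp.linalg.norm(J, 2))
print(norms)
```

With a recurrent matrix whose largest singular value exceeds 1 (and a non-saturated tanh), the same repeated product can grow instead, which is the exploding case.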

Long Short-Term Memory

A Long Short-Term Memory (LSTM) RNN is an architectural extension of the standard Elman RNN designed to mitigate the vanishing gradient problem when modeling long-range dependencies. The LSTM introduces gating mechanisms (input, forget, and output gates) that explicitly regulate how much past information is retained and how much new information is incorporated, rather than relying on implicit storage in the recurrent weights as in an Elman RNN. Crucially, the LSTM maintains a cell state whose update follows an additive structure, forming a linear path through time. These additive dynamics enable more stable gradient propagation and substantially reduce vanishing gradients during training.

Let’s dive a little deeper. Given the input \(x_{t}\) at time \(t\), the previous hidden state \(h_{t-1}\), and the previous cell state \(c_{t-1}\), the following equations control the evolution of the hidden and cell states over time.

First, the forget gate \(f_{t}\) is a parametric control for the signal that should be propagated forward from the prior cell state \(c_{t-1}\) to the current cell state \(c_{t}\) that we are computing.

\[ f_{t} = \sigma(W_{f} \cdot [h_{t-1}, x_{t}] + b_{f}) \]

The input gate \(i_{t}\), you’ll notice, is almost the same computation as the forget gate, and its role is to regulate how much signal should pass through from the candidate cell state \(\tilde{c}_{t}\) to the current cell state \(c_{t}\).

\[ i_{t} = \sigma(W_{i} \cdot [h_{t-1}, x_{t}] + b_{i}) \]

Our last gate, the output gate \(o_{t}\), is again the same computation as the other gates, and its role is to regulate how much of the current cell state \(c_{t}\) is exposed through the current hidden state \(h_{t}\).

\[ o_{t} = \sigma(W_{o} \cdot [h_{t-1}, x_{t}] + b_{o}) \]

You may be wondering: if all the computations are the same, how do the different gates perform different functions? That is a good question, and the answer is not obvious. Each gate has its own parameters (\(W_{f}\), \(W_{i}\), \(W_{o}\) and their biases), and their distinct roles emerge from how each gate modulates a different computational path in the forward pass, and from how gradients propagate through those paths during backpropagation.

In the remaining equations you can see how these gates act as controls on the input, cell state, and hidden state.

\[ \tilde{c}_{t} = \tanh(W_{c} \cdot [h_{t-1}, x_{t}] + b_{c}) \]

\[ c_{t} = f_{t} \odot c_{t-1} + i_{t} \odot \tilde{c}_{t} \]

\[ h_{t} = o_{t} \odot \tanh(c_{t}) \]
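The additive update for \(c_{t}\) is what keeps gradients healthy. Holding \(h_{t-1}\) and \(x_{t}\) fixed, \(\partial c_{t} / \partial c_{t-1} = \operatorname{diag}(f_{t})\): no weight matrix and no squashing nonlinearity sit on this path, so the gradient across many steps along it is a product of forget gate values. The sketch below checks this with `jax.jacobian`; the dimensions and random parameters are arbitrary assumptions, and the full LSTM gradient also includes paths through \(h_{t-1}\) that this check deliberately ignores.

```python
import jax
import jax.numpy as jnp

def lstm_cell_state(c_prev, h_prev, x_t, params):
    """Compute c_t (and f_t) from the LSTM equations, for fixed h_{t-1} and x_t."""
    W_f, b_f, W_i, b_i, W_c, b_c = params
    z = jnp.concatenate([h_prev, x_t])        # [h_{t-1}, x_t]
    f_t = jax.nn.sigmoid(W_f @ z + b_f)       # forget gate
    i_t = jax.nn.sigmoid(W_i @ z + b_i)       # input gate
    c_tilde = jnp.tanh(W_c @ z + b_c)         # candidate cell state
    return f_t * c_prev + i_t * c_tilde, f_t  # additive cell update

H, D = 4, 3  # hidden size and input size (arbitrary)
keys = jax.random.split(jax.random.PRNGKey(0), 6)
params = (
    jax.random.normal(keys[0], (H, H + D)), jnp.zeros(H),
    jax.random.normal(keys[1], (H, H + D)), jnp.zeros(H),
    jax.random.normal(keys[2], (H, H + D)), jnp.zeros(H),
)
c_prev = jax.random.normal(keys[3], (H,))
h_prev = jax.random.normal(keys[4], (H,))
x_t = jax.random.normal(keys[5], (D,))

# Jacobian of c_t with respect to c_{t-1} along the direct (additive) path
J = jax.jacobian(lambda c: lstm_cell_state(c, h_prev, x_t, params)[0])(c_prev)
_, f_t = lstm_cell_state(c_prev, h_prev, x_t, params)
print(jnp.allclose(J, jnp.diag(f_t)))  # True
```

When the network learns forget gates close to 1, this product stays close to 1 over many steps, which is the stable gradient path described above.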

Below is a graph that depicts the flow of computations that take place to update the hidden state.

flowchart BT
    x_t["$$x_{t}$$ (input)"] --> concat
    h_prev["$$h_{t-1}$$ (prev hidden)"] --> concat

    concat["Concatenate"] --> f_gate["Forget Gate<br> $$\sigma(W_f \cdot [h_{t-1}, x_{t}] + b_f)$$"]
    concat --> i_gate["Input Gate<br/> $$\sigma(W_i \cdot [h_{t-1}, x_t] + b_i)$$"]
    concat --> c_tilde["Candidate State<br/> $$\tanh(W_c \cdot [h_{t-1}, x_t] + b_c)$$"]
    concat --> o_gate["Output Gate<br/> $$σ(W_o \cdot [h_{t-1}, x_t] + b_o)$$"]

    c_prev["$$c_{t-1}$$ (prev cell)"] --> mult_f["×"]
    f_gate --> mult_f

    i_gate --> mult_i["×"]
    c_tilde --> mult_i

    mult_f --> c_t["$$c_t$$ (cell state)"]
    mult_i --> c_t

    c_t --> tanh_c["$$\tanh$$"]
    tanh_c --> mult_o["×"]
    o_gate --> mult_o

    mult_o --> h_t["$$h_t$$ (hidden state)"]

Figure 2: Long Short-Term Memory Recurrent Neural Network

Embeddings from Language Models

Now that we have a better understanding of RNNs, specifically LSTMs, we are ready to talk about a deep learning architecture used to generate contextual embeddings: the influential Embeddings from Language Models (ELMo) architecture.

Before transformers became popular, ELMo was one of the dominant models for generating contextual embeddings. ELMo blends character‑level convolutional neural network (CNN) token encodings, which handle out-of-vocabulary (OOV) words, multi‑layer bidirectional LSTMs that capture both left-to-right and right-to-left long-range dependencies, and a trainable weighted sum of their hidden states to produce word‑level embeddings that are both context‑sensitive and adaptable through fine‑tuning.

Architecture

Let’s look at a high-level map of ELMo’s architecture. The important sub-components are:

  1. Character-level CNN: Builds a feature vector for each token from its characters, which makes it well suited to handling OOV words.
  2. Multi-layer bidirectional LSTM: Processes the character-based token vectors in both directions, producing hidden states that capture left and right context.
  3. Contextual Hidden State: Combines forward and backward outputs to form a deep, context‑sensitive representation of each token.
  4. Scalar Mix parameters: Learn a task-specific weighted combination of the representations from different layers, enabling effective transfer of the pretrained language model to downstream tasks without retraining the full model.
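The scalar mix computes \(\text{ELMo}_{t} = \gamma \sum_{j} s_{j} h_{t,j}\) with \(s = \operatorname{softmax}(a)\), where \(h_{t,j}\) is layer \(j\)'s representation of token \(t\). Here is a minimal JAX sketch; `scalar_mix` is a hypothetical helper, and it assumes every layer representation has already been brought to a common width \(D\).

```python
import jax
import jax.numpy as jnp

def scalar_mix(layer_states, a, gamma):
    """
    layer_states: [L, T, D] stack of layer representations h_0 ... h_{L-1}.
    a:            [L] unnormalized, trainable layer weights.
    gamma:        trainable scalar scale.
    Returns [T, D] = gamma * sum_j softmax(a)_j * h_j.
    """
    s = jax.nn.softmax(a)                            # s_j >= 0 and sum_j s_j = 1
    mixed = jnp.tensordot(s, layer_states, axes=1)   # weighted sum over the layer axis
    return gamma * mixed

# Three layers (h_0, h_1, h_2) for 5 tokens with width 8 (random stand-ins)
states = jax.random.normal(jax.random.PRNGKey(0), (3, 5, 8))
a = jnp.zeros(3)  # equal weights, e.g. at initialization
gamma = 1.0
elmo = scalar_mix(states, a, gamma)
print(elmo.shape)  # (5, 8)
```

With `a` at zero the mix is a plain average of the layers; during fine-tuning only `a` and `gamma` need to be learned, which is why the pretrained model itself can stay frozen.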

Below is a high-level graph of an ELMo model.

flowchart BT

%% 1 Input
Tokens["$$\text{Token IDs } T_{1} \text{ to } T_{n}$$"] --> CharEmb[Character Embedding]

%% 2 Character CNN (h0)
CharEmb --> CharCNN[Character CNN Output]
CharCNN --> H0["$$h_{0} \text{ Character-based token representation}$$"]

%% 3 BiLSTM layers
subgraph BiLSTM[BiLSTM]
    direction RL

    Title[Bidirectional LSTM Stack]
    style Title fill:none,stroke:none

    subgraph Layer2[Layer 2]
        F2[Forward LSTM 2] --> HF2[h_fwd_2]
        B2[Backward LSTM 2] --> HB2[h_bwd_2]
    end

    subgraph Layer1[Layer 1]
        F1[Forward LSTM 1] --> HF1[h_fwd_1]
        B1[Backward LSTM 1] --> HB1[h_bwd_1]
    end
end

H0 --> BiLSTM

%% 4 Contextual states
subgraph HiddenStates[Contextual Representations]
    H1["$$h_{1} = \text{concat fwd1 bwd1}$$"]
    H2["$$h_{2} = \text{concat fwd2 bwd2}$$"]
end

HF1 --> H1
HB1 --> H1
HF2 --> H2
HB2 --> H2

%% 5 Forward & Backward LM heads
subgraph LMHeads[LM Heads]
    direction LR

    LMTitle[Language Modeling Heads]
    style LMTitle fill:none,stroke:none

    FLM[Forward LM Head] --> FSoftmax["Softmax(vocab)"]
    BLM[Backward LM Head] --> BSoftmax["Softmax(vocab)"]
end

HF1 --> FLM
HF2 --> FLM
HB1 --> BLM
HB2 --> BLM

%% 6 ELMo scalar mixer
subgraph Mixer[ELMo Scalar Mixer]
    A0["$$a_{0}$$"] --> Softmax["$$\text{Softmax over } a_{i}$$"]
    A1["$$a_{1}$$"] --> Softmax
    A2["$$a_{2}$$"] --> Softmax

    Softmax --> S0["$$s_{0}$$"]
    Softmax --> S1["$$s_{1}$$"]
    Softmax --> S2["$$s_{2}$$"]

    Sum["$$\text{Sum }s_{i}h_{i}$$"]
    Gamma[Gamma]
end

H0 --> Sum
H1 --> Sum
H2 --> Sum

S0 --> Sum
S1 --> Sum
S2 --> Sum

Sum --> Gamma

%% 7 ELMo output
subgraph ELMO[ELMo Embedding]
    ELMOvec["ELMo(t)"]
end

Gamma --> ELMOvec

style LMHeads fill:#5eaabf,stroke:#5eaabf,stroke-width:2px

Figure 3: ELMo Architecture

Pretraining Objective

ELMo uses a bidirectional language modeling objective: it jointly maximizes the likelihood of the observed tokens under a forward model that reads left-to-right and a backward model that reads right-to-left. For ease of understanding, we can break down the objective into these two directional components.

First, we have the forward language model objective, which predicts each token \(x_{t}\) given all prior tokens.

\[ P_{fwd}(x) = \prod_{t=1}^{T} P(x_{t} \mid x_{1}, \ldots, x_{t-1}) \]

Second, we have the backward language model objective, which, as the name implies, predicts each token \(x_{t}\) given all following tokens.

\[ P_{bwd}(x) = \prod_{t=1}^{T} P(x_{t} \mid x_{T}, \ldots, x_{t+1}) \]

The total pretraining objective, which we minimize, is the sum of the forward and backward negative log-likelihoods:

\[ \mathcal{L} = - \sum_{t=1}^{T} \log P(x_{t} \mid x_{1}, \ldots, x_{t-1}) \;-\; \sum_{t=1}^{T} \log P(x_{t} \mid x_{T}, \ldots, x_{t+1}) \]

Hands-on Implementation

Okay then, time to get our hands dirty! We are going to build an ELMo model from scratch using the JAX deep learning ecosystem.

Specifically, we will pre-train the model using the aforementioned language modeling objective. Subsequently, we will fine-tune our scalar mix on a downstream task to evaluate our embeddings. I hope you enjoy!

In this part of the post, I’ll collapse all but the most important code cells, leaving only the ones I’ll discuss in depth open.

Processing Utilities

To get started we are going to need some toy data. I chose the C4 dataset provided by AllenAI because it has nice, long textual examples that are relatively clean.

In order to keep things running on a local machine, I took one million examples for training and 50,000 held-out examples (a later slice of the training split) for validation. I streamed the data to my local disk once so that it is faster and easier to reload. However, with the streaming data loader we define below, you can also stream the data directly from HuggingFace.

We also need to initialize our vocabularies: one at the character level for our character-level CNN and one at the word level to feed into our language modeling heads.

For our character-level vocabulary we simply use the characters that you’d typically see in English text, in addition to an index for a padding character and an unknown character.

Our word vocabulary is specified very similarly to how we have done it in the previous posts. We build a word vocabulary of up to 100,000 words, keeping only words that occur at least 20 times.

Note

We use a simple tokenizer that splits text on whitespace. This is not the tokenization used in the original ELMo paper, but it works almost as well for our purposes.

Finally, as I alluded to earlier, we need a way to stream batches of data to the model as we train. The utility defined below handles shuffling, tokenizing, and encoding both characters and words before yielding them to the model.

Important

I have elected to use non-overlapping sequences because it is much faster and I am running on a local machine. For best results, use overlapping sequences instead. You can adjust this in the __iter__ method.
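Here is a sketch of what overlapping windows look like, using a hypothetical `make_windows` helper. It mirrors the slicing in `__iter__`: inputs are `seq_len` tokens and targets are the same window shifted by one. A `stride` equal to `seq_len` reproduces the non-overlapping behavior, while a smaller stride overlaps windows (in the loader, the analogous change would be advancing `self.token_buffer` by the stride instead of by `seq_len`).

```python
def make_windows(tokens, seq_len, stride):
    """Hypothetical helper: split a token list into (inputs, targets) windows.

    stride == seq_len -> non-overlapping windows (what the loader does)
    stride <  seq_len -> overlapping windows (more training pairs per corpus)
    """
    windows = []
    start = 0
    # Each window needs seq_len inputs plus one extra token for the shifted target.
    while start + seq_len + 1 <= len(tokens):
        x = tokens[start:start + seq_len]
        y = tokens[start + 1:start + seq_len + 1]
        windows.append((x, y))
        start += stride
    return windows

toks = list("abcdefg")
print(len(make_windows(toks, seq_len=3, stride=3)))  # 2 non-overlapping windows
print(len(make_windows(toks, seq_len=3, stride=1)))  # 4 overlapping windows
```

The trade-off is throughput: a stride of 1 produces roughly `seq_len` times as many windows from the same text, which is why the non-overlapping version is much faster on a local machine.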

Code
import os

from functools import partial
from collections import Counter
import itertools
import jax
import jax.numpy as jnp
from flax import nnx
import optax
import numpy as np

from datasets import load_dataset, DownloadConfig, Dataset, load_from_disk
import string
import random
import orbax.checkpoint as ocp

from datasets.utils.logging import disable_progress_bar
disable_progress_bar()
Code
# Stream data to local disk: one time only
# You can also just stream from HF, however on reload (i.e. next epoch) you need to stream again (re-download).

def train_gen():
    """
    Streams English C4 training data from the AllenAI HuggingFace repo. We subset the data to the first 1M examples.
    """
    stream = load_dataset(
        "allenai/c4",
        "en",
        split="train",
        streaming=True,
    )
    for i, row in enumerate(stream):
        if i >= 1_000_000:
            break
        yield row

def val_gen():
    """
    Streams a held-out slice of English C4 (rows 2,000,000 to 2,049,999 of the
    train split) to use as validation data, i.e. 50k examples.
    """
    stream = load_dataset(
        "allenai/c4",
        "en",
        split="train",
        streaming=True,
    )
    for i, row in enumerate(stream):
        if i < 2_000_000:
            continue
        if i >= 2_050_000:
            break
        yield row


# In batches of 1k save the data to disk
train_ds = Dataset.from_generator(
    train_gen,
    writer_batch_size=1_000,
)

train_ds.save_to_disk("c4_train")

val_ds = Dataset.from_generator(
    val_gen,
    writer_batch_size=1_000,
)

val_ds.save_to_disk("c4_val")
Code
# Tokens used when padding or encountering unknown characters in a string.
PAD_CHAR, UNK_CHAR = "<pad>", "<unk>"

# Tokens used when padding or encountering unknown words in a sentence.
PAD_WORD, UNK_WORD = "<pad>", "<unk>"

# Characters that we expect to see in typical English text
char_vocab = [PAD_CHAR, UNK_CHAR] + list(string.ascii_lowercase + string.ascii_uppercase + string.digits + string.punctuation)
char_to_id = {ch: i for i, ch in enumerate(char_vocab)}

def build_word_vocab(dataset, batch_size=32, min_freq=20, max_vocab=100_000) -> tuple[dict[str, int], dict[int, str]]:
    """
    Build a word–to–index mapping from a text dataset.

    The function scans the provided dataset for tokenised words
    (splitting on whitespace) and retains only words that appear at least
    ``min_freq`` times in the corpus.  The returned dictionary is
    capped to at most ``max_vocab`` entries (plus ``<pad>`` and ``<unk>``).

    Parameters
    ----------
    dataset
        Iterable of records where each record is a mapping
        (``dict``) that contains a ``"text"`` key holding the raw string.
    batch_size
        Number of records processed in a single counting pass.
    min_freq
        Minimum frequency required for a word to be included in the
        resulting vocabularies.
    max_vocab
        Maximum number of words (excluding ``<pad>`` and ``<unk>``) to keep in
        the vocabularies.  The words chosen are the most frequent ones that
        pass ``min_freq`` filtering.

    Returns
    -------
    word_to_index : dict
        Mapping from a word string to a unique integer ID, with
        ``<pad>`` mapping to 0 and ``<unk>`` mapping to 1.
    index_to_word : dict
        Inverse mapping from integer ID back to the corresponding word.

    Notes
    -----
    * The tokenisation strategy is simplistic: it uses the standard
      ``str.split()`` which splits on any whitespace.  For more advanced
      tokenisation pipelines (e.g., handling sub‑words, hyphenated
      compounds, etc.) replace the ``s.split()`` call accordingly.
    * Records are buffered in groups of ``batch_size`` and their tokens
      are counted one record at a time with ``Counter.update``; only the
      counter, not the corpus, is held in memory.

    """
    counter = Counter()
    buf = []

    for row in dataset:
        buf.append(row["text"])
        if len(buf) >= batch_size:
            for s in buf:
                counter.update(s.split())
            buf = []

    if buf:
        for s in buf:
            counter.update(s.split())

    # top words with min frequency
    most_common = [(w, c) for w, c in counter.most_common(max_vocab) if c >= min_freq]
    word_to_index = {PAD_WORD: 0, UNK_WORD: 1}
    index_to_word = {0: PAD_WORD, 1: UNK_WORD}
    for i, (w, _) in enumerate(most_common, start=2):
        word_to_index[w] = i
        index_to_word[i] = w
    return word_to_index, index_to_word
Code
class StreamingTextDataLoader:
    """
    A streaming data loader for text corpora that yields batches of tokenized
    sequences.

    This class supports streaming from an arbitrary dataset, automatically shuffling,
    tokenizing into fixed‑length windows (with stride = seq_len), and encoding both
    words and characters. It provides batched dictionaries containing ``word_ids``,
    ``char_ids`` and ``target_ids`` as JAX arrays.

    Parameters
    ----------
    ds: Iterable[dict]
        A dataset yielding rows that contain a text field (e.g., "text", "sentence",
        "review", or "content").  Each row should be a mapping from column name to value.
    vocab: dict[str, int] | Vocab object
        Mapping from token strings to unique integer IDs.
    char_to_id: dict[str, int]
        Mapping from characters to their integer IDs.
    seq_len: int
        Length of the input sequence window (number of tokens). Each yielded batch will contain sequences of this size.
    word_len: int
        Maximum length of each token when encoded at the character level. Tokens longer than this are truncated; shorter ones are padded with ``PAD_CHAR``.
    batch_size: int
        Number of examples to accumulate before yielding a batched dictionary.
    shuffle_buffer: int, optional (default=2048)
        Number of rows (documents) buffered and shuffled at a time. Larger buffers
        provide better randomness at the cost of memory usage.
    seed: int or None, optional (default=0)
        Random seed for reproducibility when shuffling.

    Attributes
    ----------
    self.ds: original iterable dataset.
    self.vocab: token vocabulary.
    self.char_to_id: character‑to‑ID mapping.
    self.seq_len: sequence length used for chunking.
    self.word_len: maximum character width per token.
    self.batch_size: number of samples per batch to yield.
    self.shuffle_buffer: size of the shuffle buffer.
    self.seed: RNG seed.
    self.token_buffer: internal queue holding pre‑assembled tokens awaiting windowing.

    Yields
    ------
    dict
        A dictionary with three JAX arrays:
        ``word_ids`` – shape ``(batch_size, seq_len)`` integer IDs for each token,
        ``char_ids`` – shape ``(batch_size, seq_len, word_len)`` character‑level IDs,
        ``target_ids`` – shifted version of ``word_ids`` used as language‑model targets.

    Notes
    -----
    * The class performs a non‑overlapping stride split (`self.token_buffer =
      self.token_buffer[self.seq_len:]`) which can be changed to an overlapping stride
      if desired for better coverage.
    * All returned arrays are JAX ``jnp`` objects.
    """
    def __init__(self, ds, vocab, char_to_id, seq_len, word_len, batch_size, shuffle_buffer=2048, seed=0):
        self.ds = ds
        self.vocab = vocab
        self.char_to_id = char_to_id
        self.seq_len = seq_len
        self.word_len = word_len
        self.batch_size = batch_size
        self.shuffle_buffer = shuffle_buffer
        self.seed = seed

        self.token_buffer = []

    def _get_text_field(self, row):
        for key in ["text", "sentence", "review", "content"]:
            if key in row:
                return row[key]
        raise KeyError(f"No text field found in row keys: {list(row.keys())}")


    def _encode_window(self, toks):
        word_ids = [self.vocab.get(w, self.vocab.get(UNK_WORD)) for w in toks]

        char_ids = np.full(
            (self.seq_len, self.word_len),
            self.char_to_id[PAD_CHAR],
            dtype=np.int32,
        )

        for i, w in enumerate(toks):
            cids = [self.char_to_id.get(c, self.char_to_id[UNK_CHAR]) for c in w[:self.word_len]]
            char_ids[i, :len(cids)] = cids

        return {
            "word_ids": np.array(word_ids, dtype=np.int32),
            "char_ids": char_ids,
        }

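    def _shuffle_buffer_iter(self, ds):
        # Seeded RNG created once per loader: shuffling is reproducible across
        # runs (via ``seed``) while still differing from one epoch to the next.
        if not hasattr(self, "_rng"):
            self._rng = random.Random(self.seed)
        buf = []
        for row in ds:
            buf.append(row)
            if len(buf) >= self.shuffle_buffer:
                self._rng.shuffle(buf)
                while buf:
                    yield buf.pop()
        self._rng.shuffle(buf)
        while buf:
            yield buf.pop()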

    def __iter__(self):
        self.token_buffer = []
        batch_words, batch_chars, batch_targets = [], [], []

        for row in self._shuffle_buffer_iter(self.ds):
            text = self._get_text_field(row)

            new_toks = text.split()
            self.token_buffer.extend(new_toks)

            while len(self.token_buffer) >= self.seq_len + 1:
                x_toks = self.token_buffer[:self.seq_len]
                y_toks = self.token_buffer[1:self.seq_len + 1]

                # Non-overlapping stride (for speed) ideally you want overlapping
                self.token_buffer = self.token_buffer[self.seq_len:]

                x_enc = self._encode_window(x_toks)
                y_ids = np.array(
                    [self.vocab.get(w, self.vocab.get(UNK_WORD)) for w in y_toks],
                    dtype=np.int32,
                )

                batch_words.append(x_enc["word_ids"])
                batch_chars.append(x_enc["char_ids"])
                batch_targets.append(y_ids)

                if len(batch_words) == self.batch_size:
                    yield {
                        "word_ids": jnp.stack(batch_words),
                        "char_ids": jnp.stack(batch_chars),
                        "target_ids": jnp.stack(batch_targets),
                    }
                    batch_words, batch_chars, batch_targets = [], [], []

Build Vocabulary

If you have also opted to save the toy dataset to disk, then the next step is to load it and build the word vocabulary. Note that we have already built the character vocabulary above.

Code
hf_ds = load_from_disk("c4_train")
vocab, _ = build_word_vocab(hf_ds)

Define Architecture Components

We covered the bulk of the architecture above with the diagrams and the LSTM overview; however, I want to focus on a component within the character-level CNN called the highway network.

The CNN applies a set of parallel one‑dimensional convolutions with varying kernel widths over the character sequences. Each filter’s output is max‑pooled across the temporal dimension, and the pooled features from all filters are concatenated to form a fixed‑size vector. This vector is then fed into a highway network, which learns a gated mixture of the raw CNN output and a transformed version of it, yielding the final character‑level embedding.
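To see why max-pooling yields a fixed-size vector regardless of word length, here is a bare-bones JAX sketch of the convolution and pooling steps for a single token. It is not the `CharCNN` module defined below: `char_cnn_features` is a hypothetical helper, the kernels are random, and it only slides over full windows ("valid" positions), whereas the padding behavior of a library convolution layer may differ.

```python
import jax
import jax.numpy as jnp

def char_cnn_features(char_emb, kernels):
    """
    char_emb: [W, D] embeddings for the W characters of one token.
    kernels:  list of [k, D, C] filters (width k, C output channels), with k <= W.
    Returns a vector of length sum(C), independent of W.
    """
    feats = []
    W = char_emb.shape[0]
    for kern in kernels:
        k = kern.shape[0]
        # Every length-k window of characters: [W - k + 1, k, D]
        windows = jnp.stack([char_emb[i:i + k] for i in range(W - k + 1)])
        # 1-D convolution (cross-correlation, as in deep learning libraries): [W - k + 1, C]
        conv = jnp.einsum("nkd,kdc->nc", windows, kern)
        # ReLU, then max-pool over all window positions -> [C]
        feats.append(jnp.max(jax.nn.relu(conv), axis=0))
    # Concatenate the pooled features of all filters
    return jnp.concatenate(feats)

k1, k2, k3, k4 = jax.random.split(jax.random.PRNGKey(0), 4)
kernels = [jax.random.normal(k1, (2, 4, 3)), jax.random.normal(k2, (3, 4, 5))]
short_word = jax.random.normal(k3, (6, 4))   # 6 characters, 4-dim char embeddings
long_word = jax.random.normal(k4, (12, 4))   # 12 characters
print(char_cnn_features(short_word, kernels).shape)  # (8,)
print(char_cnn_features(long_word, kernels).shape)   # (8,)
```

Both words map to 3 + 5 = 8 features, one per filter channel, which is what lets the highway layers and the BiLSTM consume tokens of any length.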

More specifically, we define the transformed version of the input vector \(x\) as

\[ \text{trans}(x) = \text{ReLU}(W_{t}x + b_{t}) \]

and the learnable gate as

\[ \text{gate}(x) = \sigma(W_{g}x + b_{g}) \]

and we blend the two together with

\[ \text{highway}(x) = \text{gate}(x) \odot \text{trans}(x) + (1 - \text{gate}(x)) \odot x \]

You’ll notice that because the gate is a sigmoid, its values can approach zero, in which case the raw CNN outputs flow through nearly unchanged. The highway layers add robustness to vanishing/exploding gradients and to noisy CNN outputs. They also allow learning more diverse representations, because the network can either keep low-level, high-frequency patterns or mix them into higher-level, smoothed features.

class Highway(nnx.Module):
    """
    A simple ``Highway`` network module

    The layer implements the classic gating mechanism that controls how much of the
    transformed input should be let through versus left for a direct residual shortcut.

    Parameters
    ----------
    dim : int
        Dimensionality of the input and output vectors.
    rngs : nnx.Rngs
        Random number generators.

    Returns
    -------
    nnx.Module
        A callable module whose ``__call__(self, x)`` method returns

        .. code-block:: python

            H * T + x * (1 - T)

        where

        * ``proj(x)`` --> ``H = relu(proj(x))`` is the transformed (candidate) signal,
        * ``trans(x)`` --> ``T = sigmoid(trans(x))`` is a gate in ``[0, 1]``,
        * ``x * (1 - T)`` passes through the original input weighted by the complement of
          the gate.

    """
    def __init__(self, dim, *, rngs: nnx.Rngs):
        self.proj = nnx.Linear(dim, dim, rngs=rngs)
        self.trans = nnx.Linear(dim, dim, rngs=rngs)

    def __call__(self, x):
        H = jax.nn.relu(self.proj(x))
        T = jax.nn.sigmoid(self.trans(x))
        return H * T + x * (1 - T)


class CharCNN(nnx.Module):
    """
    Character‑level convolutional encoder that produces a dense
    representation for each token in a sequence.

    The module first embeds each input character, applies a set of 1‑D
    convolutions across the character dimension, performs a global
    max‑pool, concatenates the filter responses, and optionally
    processes them through Highway layers and a final projection
    layer.

    Parameters
    ----------
    vocab_size : int
        Size of the character vocabulary.  Must be at least the number
        of distinct characters (plus any padding/unknown tokens) used
        in the input data.

    char_dim : int
        Dimensionality of the learned character embeddings.

    filters : Sequence[tuple[int, int]]
        A list of `(width, out_channels)` pairs that specify the
        kernel width and number of output channels for each 1‑D
        convolution.

    highway_layers : int
        Number of Highway layers applied after the convolution
        stack.

    proj_dim : int | None, default ``None``
        If given, a final linear projection is applied to the
        concatenated convolution+highway features to reduce (or
        expand) the dimensionality to ``proj_dim``.  If ``None``,
        the output dimensionality equals the total number of
        convolution channels.

    rngs : nnx.Rngs
        Random number generator.

    Notes
    -----
    * **Input shape**: ``char_ids`` must have shape
      ``[B, T, W]`` where ``B`` is the batch size, ``T`` the number of
      tokens per sequence, and ``W`` the maximum number of characters
      per token (words are padded to this length).
    * Each convolution slides along the character dimension and is
      followed by a ReLU activation and a channel‑wise global max‑pool
      over the remaining character positions, resulting in a single
      scalar per filter channel.
    * If ``proj_dim`` is ``None`` the output shape will be
      ``[B, T, total_filters]`` where
      ``total_filters`` is the sum of all ``out_channels`` across
      the filters.  If a projection is used the output shape is
      ``[B, T, proj_dim]``.

    Returns
    -------
    jnp.ndarray
        Character‑derived embedding of shape ``[B, T, proj_dim]`` if a
        projection layer is supplied, otherwise ``[B, T, total_filters]``.
    """
    def __init__(self, vocab_size, char_dim, filters, highway_layers, proj_dim=None, *, rngs: nnx.Rngs):
        self.emb = nnx.Embed(vocab_size, char_dim, rngs=rngs)
        self.convs = nnx.List([nnx.Conv(
            in_features=char_dim,
            out_features=out_channels,
            kernel_size=(width,),
            feature_group_count=1,
            use_bias=True,
            rngs=rngs
        )
        for (width, out_channels) in filters])
        # highway layers
        total_filters = sum(out for _, out in filters)
        self.highways = nnx.List([Highway(total_filters, rngs=rngs) for _ in range(highway_layers)])
        self.proj_dim = proj_dim
        if proj_dim is not None:
            self.proj = nnx.Linear(total_filters, proj_dim, rngs=rngs)

    def __call__(self, char_ids):
        B, T, W = char_ids.shape
        # embed -> [B, T, W, D]
        x = self.emb(char_ids)
        # apply convs along the word-length axis: we first reshape to merge batch/time
        x_flat = x.reshape((B*T, W, x.shape[-1]))  # [B*T, W, D]
        conv_outs = []
        for conv in self.convs:
            y = conv(x_flat)  # [B*T, new_len, out_ch]
            y = jax.nn.relu(y)
            y = jnp.max(y, axis=1)  # max-pool over positions
            conv_outs.append(y)
        x_cat = jnp.concatenate(conv_outs, axis=-1)  # [B*T, total_filters]
        # highways
        for h in self.highways:
            x_cat = h(x_cat)
        if self.proj_dim is not None:
            x_cat = self.proj(x_cat)  # [B*T, proj_dim]
        return x_cat.reshape((B, T, -1))  # [B, T, embed_dim]
Code
class LSTMCell(nnx.Module):
    """
    A single LSTM cell implemented in ``flax.nnx`` that supports
    dropout on both the input and the recurrent connection.

    Parameters
    ----------
    input_dim : int
        Dimensionality of the input vector ``x_t``
    hidden_dim : int
        Number of LSTM hidden units; defines the size of the hidden
        state ``h`` and cell state ``c``.
    dropout : float, default 0.0
        Dropout probability applied to both the input vector and the
        recurrent hidden state when ``deterministic=False``.
    rngs : nnx.Rngs | None, default ``None``
        Random number generators.

    Notes
    -----
    * Internal weight maps:

      ``Wx``: ``input_dim -> 4 * hidden_dim``
      ``Wh``: ``hidden_dim -> 4 * hidden_dim``

      The four output channels correspond respectively to the
      *input*, *forget*, *output*, and *candidate* gates.

    * Dropout is applied according to the standard
      “inverted” scheme (`mask / (1‑p)`), where the mask is drawn
      from a Bernoulli distribution with probability `1‑dropout`.
      Two independent masks are used: one for the current input
      ``x_t`` and one for the recurrent hidden state ``h``.  A
      distinct RNG key must be passed via ``jax_rng`` when
      ``deterministic=False``.

    * The carry is a tuple ``(h, c)`` where each component has
      shape ``[batch, hidden_dim]``.  The method returns
      ``((h_new, c_new), h_new)``: the first element is the updated
      carry, and the second element, ``h_new``, is the per-step output
      that ``jax.lax.scan`` (as used in :class:`BiLSTMLayer`) or another
      sequence wrapper stacks over time.

    Returns
    -------
    tuple
        * ``((h_new, c_new), h_new)``  
          where ``h_new`` and ``c_new`` are arrays of shape
          ``[batch, hidden_dim]`` representing the updated
          hidden and cell states, respectively.  The second
          ``h_new`` is the output of this cell and matches the
          batch dimension of ``x_t``.
    """
    def __init__(self, input_dim, hidden_dim, dropout=0.0, *, rngs=None):
        self.Wx = nnx.Linear(input_dim, 4 * hidden_dim, rngs=rngs)
        self.Wh = nnx.Linear(hidden_dim, 4 * hidden_dim, rngs=rngs)
        self.hidden_dim = hidden_dim
        self.dropout = dropout

    def __call__(self, carry, x_t, deterministic=True, jax_rng=None):
        h, c = carry

        if not deterministic:
            assert jax_rng is not None, "RNG key must be passed for dropout"
            rng_inp, rng_rec = jax.random.split(jax_rng)
            x_mask = jax.random.bernoulli(rng_inp, 1.0 - self.dropout, x_t.shape)
            x_t = x_t * x_mask / (1.0 - self.dropout)

            h_mask = jax.random.bernoulli(rng_rec, 1.0 - self.dropout, h.shape)
            h = h * h_mask / (1.0 - self.dropout)

        gates = self.Wx(x_t) + self.Wh(h)
        i, f, o, g = jnp.split(gates, 4, axis=-1)
        i = jax.nn.sigmoid(i)
        f = jax.nn.sigmoid(f)
        o = jax.nn.sigmoid(o)
        g = jnp.tanh(g)
        c_new = f * c + i * g
        h_new = o * jnp.tanh(c_new)
        return (h_new, c_new), h_new


class BiLSTMLayer(nnx.Module):
    """
    A bidirectional LSTM layer built on top of :class:`LSTMCell`.

    Each input sequence is processed in two directions:
    * a forward LSTM that runs from the first to the last token,
    * a backward LSTM that runs from the last to the first token.
    The outputs of the two directions are returned separately and also
    concatenated along the feature dimension.

    Parameters
    ----------
    in_dim : int
        Dimensionality of input tokens.
    hidden_dim : int
        Size of the hidden state in each direction (`h` and `c`).
    dropout : float, default 0.0
        Dropout probability applied inside each :class:`LSTMCell`.  The
        same dropout probability is used for both forward and backward
        streams.
    rngs : nnx.Rngs | None, default ``None``
        Random number generators

    Notes
    -----
    * **RNG handling** -  
      When ``deterministic=False`` a single ``jax_rng`` is split once
      into two keys (for forward and backward).  These keys are then
      split further into a per‑time‑step key that is passed to
      :class:`LSTMCell`.  For deterministic execution ``rngs_fwd`` and
      ``rngs_bwd`` are ``None``, which ``jax.lax.scan`` treats as an
      empty pytree, so every step receives ``rng_t = None``.
    * **State management** -  
      Each direction starts from a zero initial hidden and cell state of
      shape ``(B, hidden_dim)``.
    * **Output dimensions** -  
      For an input of shape ``(B, T, D)`` the three returned tensors have
      shapes:
        * ``hs_fwd``: ``(B, T, hidden_dim)``
        * ``hs_bwd``: ``(B, T, hidden_dim)``
        * ``hs_concat``: ``(B, T, 2 * hidden_dim)``

    Returns
    -------
    tuple
        ``(hs_fwd, hs_bwd, hs_concat)``
        * ``hs_fwd`` - hidden states produced by the forward LSTM,
        * ``hs_bwd`` - hidden states produced by the backward LSTM,
        * ``hs_concat`` - concatenation of ``hs_fwd`` and ``hs_bwd``.

    """
    def __init__(self, in_dim, hidden_dim, dropout=0.0, *, rngs=None):
        self.fwd = LSTMCell(in_dim, hidden_dim, dropout=dropout, rngs=rngs)
        self.bwd = LSTMCell(in_dim, hidden_dim, dropout=dropout, rngs=rngs)

    def __call__(self, inputs, deterministic=True, jax_rng=None):
        B, T, D = inputs.shape
        h0 = jnp.zeros((B, self.fwd.hidden_dim))
        c0 = jnp.zeros((B, self.fwd.hidden_dim))

        if not deterministic and jax_rng is not None:
            # Pre-split RNGs for each time step
            rng_fwd, rng_bwd = jax.random.split(jax_rng)
            rngs_fwd = jax.random.split(rng_fwd, T)
            rngs_bwd = jax.random.split(rng_bwd, T)

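        else:
            # None is an empty pytree, so scan passes rng_t=None at every step.
            # (A Python list of T Nones would instead hand each step the whole list.)
            rngs_fwd = None
            rngs_bwd = None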

        def fwd_scan(carry, x_and_rng):
            x_t, rng_t = x_and_rng
            return self.fwd(carry, x_t, deterministic=deterministic, jax_rng=rng_t)

        def bwd_scan(carry, x_and_rng):
            x_t, rng_t = x_and_rng
            return self.bwd(carry, x_t, deterministic=deterministic, jax_rng=rng_t)

        xs_fwd = (inputs.swapaxes(0, 1), rngs_fwd)
        xs_bwd = (jnp.flip(inputs, axis=1).swapaxes(0, 1), rngs_bwd)

        _, hs_fwd = jax.lax.scan(fwd_scan, (h0, c0), xs_fwd)
        hs_fwd = hs_fwd.swapaxes(0, 1)

        _, hs_bwd_rev = jax.lax.scan(bwd_scan, (h0, c0), xs_bwd)
        hs_bwd = jnp.flip(hs_bwd_rev.swapaxes(0, 1), axis=1)

        return hs_fwd, hs_bwd, jnp.concatenate([hs_fwd, hs_bwd], axis=-1)


class StackedBiLSTM(nnx.Module):
    """
    A stack of bidirectional LSTM layers implemented with :class:`BiLSTMLayer`.

    Each layer receives the concatenated hidden states of the layer below
    it (``h_fwd`` || ``h_bwd``) as input; the first layer receives the raw input.

    Parameters
    ----------
    input_dim : int
        Dimensionality of the raw token embeddings fed to the first
        :class:`BiLSTMLayer`.  The dimensionality of all following layers
        will be `2 * hidden_dim`.
    hidden_dim : int
        Hidden size of each unidirectional LSTM within every
        :class:`BiLSTMLayer`.  The actual output of a layer
        has shape ``(B, T, 2 * hidden_dim)``.
    num_layers : int
        Number of stacked bidirectional layers.
    dropout : float, default 0.0
        Dropout probability applied inside each :class:`LSTMCell`.
    rngs : nnx.Rngs | None, default ``None``
        Random number generators

    Returns
    ------- 
      1. ``outs`` - A list containing the *raw input* followed by the
         concatenated output of each layer.  Hence ``len(outs) ==
         num_layers + 1`` and the last element has shape
         ``(B, T, 2 * hidden_dim)``.
      2. ``fwd_states`` - A list of the forward hidden states from each
         layer (shape ``(B, T, hidden_dim)``).
      3. ``bwd_states`` - A list of the backward hidden states from each
         layer (shape ``(B, T, hidden_dim)``).

    """
    def __init__(self, input_dim, hidden_dim, num_layers, dropout=0.0, *, rngs=None):
        self.layers = nnx.List([
            BiLSTMLayer(input_dim if i == 0 else 2*hidden_dim, hidden_dim, dropout=dropout, rngs=rngs)
            for i in range(num_layers)
        ])

    def __call__(self, x, deterministic=False, jax_rng=None):
        outs = [x]
        fwd_states, bwd_states = [], []

        if not deterministic and jax_rng is not None:
            rngs = jax.random.split(jax_rng, len(self.layers))
        else:
            rngs = [None] * len(self.layers)

        for layer, r in zip(self.layers, rngs):
            fwd, bwd, x = layer(x, deterministic=deterministic, jax_rng=r)
            fwd_states.append(fwd)
            bwd_states.append(bwd)
            outs.append(x)
        return outs, fwd_states, bwd_states


class LMHead(nnx.Module):
    """
    Language‑model output head that projects hidden states to logits over a
    target vocabulary.

    The module consists of a single linear transformation that maps the
    last hidden dimension of the model to a vector of size ``vocab_size``.

    Parameters
    ----------
    hidden_dim : int
        Dimensionality of the input hidden states ``h``.
    vocab_size : int
        Size of the target vocabulary (number of output logits).
    rngs : nnx.Rngs
        Random number generators

    Returns
    -------
    jax.numpy.ndarray
        Logits of shape ``(B, T, vocab_size)``

    """
    def __init__(self, hidden_dim, vocab_size, *, rngs: nnx.Rngs):
        self.linear = nnx.Linear(hidden_dim, vocab_size, rngs=rngs)
    def __call__(self, h):
        return self.linear(h)  # [B,T,V]
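A quick way to check that a bidirectional layer wires its two directions correctly is a causality test. The forward state at position t should depend only on tokens 0..t, and the backward state only on tokens t..T-1. The cell below is a self-contained, pure-JAX sketch of the same flip-scan-flip pattern used in BiLSTMLayer. It uses hypothetical toy sizes, and plain weight matrices stand in for the nnx.Linear layers, so it illustrates the idea rather than reproducing the model's code.

```python
import jax
import jax.numpy as jnp

# Hypothetical toy sizes, chosen only for this check
B, T, D, H = 2, 6, 3, 4

def init_lstm(key, in_dim, hidden_dim):
    """Random weights for one LSTM direction; gates packed as [i, f, o, g]."""
    kx, kh = jax.random.split(key)
    return {
        "Wx": jax.random.normal(kx, (in_dim, 4 * hidden_dim)) * 0.1,
        "Wh": jax.random.normal(kh, (hidden_dim, 4 * hidden_dim)) * 0.1,
        "b": jnp.zeros(4 * hidden_dim),
    }

def lstm_step(params, carry, x_t):
    """One LSTM step with the same gate order as LSTMCell; returns (carry, output)."""
    h, c = carry
    gates = x_t @ params["Wx"] + h @ params["Wh"] + params["b"]
    i, f, o, g = jnp.split(gates, 4, axis=-1)
    c_new = jax.nn.sigmoid(f) * c + jax.nn.sigmoid(i) * jnp.tanh(g)
    h_new = jax.nn.sigmoid(o) * jnp.tanh(c_new)
    return (h_new, c_new), h_new

def run_direction(params, x):
    """Scan left to right over x of shape [B, T, D]; returns hidden states [B, T, H]."""
    hidden = params["Wh"].shape[0]
    zeros = jnp.zeros((x.shape[0], hidden))
    step = lambda carry, x_t: lstm_step(params, carry, x_t)
    _, hs = jax.lax.scan(step, (zeros, zeros), x.swapaxes(0, 1))  # scan over the time axis
    return hs.swapaxes(0, 1)

def bilstm(p_fwd, p_bwd, x):
    hs_fwd = run_direction(p_fwd, x)
    # Backward direction: reverse time, scan, then reverse the outputs back
    hs_bwd = jnp.flip(run_direction(p_bwd, jnp.flip(x, axis=1)), axis=1)
    return hs_fwd, hs_bwd

k1, k2, k3, k4 = jax.random.split(jax.random.PRNGKey(0), 4)
p_fwd, p_bwd = init_lstm(k1, D, H), init_lstm(k2, D, H)
x = jax.random.normal(k3, (B, T, D))

t = 2
x_future = x.at[:, t + 1:].set(jax.random.normal(k4, (B, T - t - 1, D)))  # edit tokens after t
x_past = x.at[:, :t].set(0.0)                                             # edit tokens before t

f0, b0 = bilstm(p_fwd, p_bwd, x)
f1, _ = bilstm(p_fwd, p_bwd, x_future)
_, b2 = bilstm(p_fwd, p_bwd, x_past)

print(jnp.allclose(f0[:, :t + 1], f1[:, :t + 1]))  # True: forward states up to t never see the future
print(jnp.allclose(b0[:, t:], b2[:, t:]))          # True: backward states from t never see the past
```

You can run the same comparison against the real BiLSTMLayer with deterministic=True. It makes a cheap regression test if you ever change the flipping logic.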

We define our ELMo model by initializing and chaining together the subcomponents:

  1. Character Level CNN
  2. Bidirectional LSTM
  3. Language Modeling Heads

In addition, we initialize ELMo’s scalar mix parameters, which adapt the embeddings to specific downstream tasks during fine-tuning, and we regularize the model with dropout. Our ELMo model applies dropout at three points: on the character-CNN output (input dropout), inside each LSTM cell, and on the top-layer LSTM states that feed the language-modeling heads (output dropout).

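In our implementation you’ll notice that we project every layer’s output into a common dimension (common_dim). The character CNN produces proj_dim-dimensional vectors while each BiLSTM layer produces 2 * hidden_dim-dimensional vectors, so they must share a size before the scalar mix can sum them. You’ll also notice that the method forward_embeddings is only used during fine-tuning, to tune the scalar mix parameters on specific downstream tasks.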
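Concretely, the scalar mix multiplies gamma by a weighted sum of the projected layers. The weights are the softmax of the learnable scalar weights, layer 0 is the character-CNN output, and the remaining layers are the concatenated BiLSTM states. The pure-JAX sketch below uses hypothetical toy sizes and bias-free projections (the model itself uses nnx.Linear, which also adds a bias). scalar_mix is a hypothetical helper that mirrors forward_embeddings.

```python
import jax
import jax.numpy as jnp

# Hypothetical toy sizes: CharCNN output (proj_dim) and BiLSTM output (2 * hidden_dim) differ
B, T, proj_dim, two_h, common_dim, num_layers = 2, 4, 8, 12, 5, 2

keys = jax.random.split(jax.random.PRNGKey(0), 2 * (num_layers + 1))
# Layer 0 is the CharCNN output; layers 1..num_layers are concatenated fwd || bwd states
layers = [jax.random.normal(keys[0], (B, T, proj_dim))] + [
    jax.random.normal(keys[j], (B, T, two_h)) for j in range(1, num_layers + 1)
]
# One bias-free projection per layer into the shared common_dim space
in_dims = [proj_dim] + [two_h] * num_layers
proj_W = [
    jax.random.normal(keys[num_layers + 1 + j], (d, common_dim)) for j, d in enumerate(in_dims)
]

def scalar_mix(scalar_weights, gamma, layers, proj_W):
    """Hypothetical helper: gamma * sum_j softmax(scalar_weights)_j * (layer_j @ W_j)."""
    w = jax.nn.softmax(scalar_weights)                 # non-negative, sums to 1
    projected = [layer @ W for layer, W in zip(layers, proj_W)]
    return gamma * sum(w_j * p for w_j, p in zip(w, projected))

# Zero initialisation (as in ElmoModel) gives every layer weight 1 / (num_layers + 1)
out = scalar_mix(jnp.zeros(num_layers + 1), 1.0, layers, proj_W)
print(out.shape)  # (2, 4, 5)

# A large logit on one layer makes the mix select (almost only) that layer
peaked = scalar_mix(jnp.array([0.0, 0.0, 50.0]), 2.0, layers, proj_W)
```

Because the weights pass through a softmax, fine-tuning can only redistribute attention across layers. Any overall change in magnitude has to come through gamma.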

Important

Output dropout regularizes language-model training, not downstream embeddings. Scalar-mixed embeddings themselves are not dropout-regularized in this implementation.

class ElmoModel(nnx.Module):
    """
    The model implements the core components of the original ELMo
    architecture: a character‑level CNN that produces sub‑word
    representations, a stack of bidirectional LSTMs that encode the
    sentence, and forward/backward language‑model heads. A
    *scalar‑mix* (parameterised by a softmax over learnable weights)
    combines the character‑CNN output and every LSTM layer into a
    fixed‑dimensional semantic vector (`common_dim`).  This vector can
    be used as contextualised word embeddings downstream.

    Parameters
    ----------
    char_vocab_size : int
        Vocabulary size for character indices.
    char_dim : int
        Size of the character embedding vectors.
    filters : Sequence[Tuple[int, int]]
        List of ``(num_filters, filter_width)`` tuples that define the
        convolutional channels in the character CNN.
    highway_layers : int
        Number of highway network layers in the character CNN.
    proj_dim : int
        Dimensionality of the output of the projection layer that comes
        after the character CNN.
    common_dim : int
        Dimensionality of the final ELMo embedding.
    hidden_dim : int
        Hidden state size of each BiLSTM cell.
    num_layers : int
        Number of stacked BiLSTM layers.
    word_vocab_size : int
        Size of the vocabulary for the forward and backward language‑model
        heads.
    input_dropout : float, default 0.1
        Dropout probability applied to the output of the character CNN
        during training.
    lstm_dropout : float, default 0.1
        Dropout probability applied within each BiLSTM layer during
        training.
    output_dropout : float, default 0.1
        Dropout probability applied to the top‑layer LSTM states before
        they are fed to the language‑model heads.
    rngs : nnx.Rngs
        JAX random number generator state used to initialise parameters.

    Notes
    -----
    * The character‑CNN (`self.char_cnn`) maps the raw character IDs to
      a vector of dimensionality ``proj_dim``.  This vector is then
      projected to the common embedding space (`common_dim`) by a
      linear layer.
    * Each BiLSTM layer produces a forward state of shape
      `(batch, seq_len, hidden_dim)` and a backward state of the
      same shape.  The states of a layer are concatenated along the
      feature dimension and projected to ``common_dim``.
    * The scalar mix treats the character‑CNN output as layer 0 and
      each BiLSTM layer as a subsequent layer.  The weight vector
      (`self.scalar_weights`) is soft‑maxed so that the weights sum
      to 1. The result is scaled by the learnable `gamma` parameter.
    * During evaluation (``deterministic=True``) drop‑outs are disabled
      and the same provided RNG is used to keep the computation
      deterministic.

    Methods
    -------
    forward_backbone(char_ids, jax_rng, deterministic=True)
        Compute the character embeddings, forward and backward LSTM states.

        Returns
        -------
        char_embs : jnp.ndarray
            The raw output of the character CNN, shape
            `(batch, seq_len, proj_dim)`.
        fwd_states : List[jnp.ndarray]
            Forward LSTM states for each of the ``num_layers`` layers,
            each of shape `(batch, seq_len, hidden_dim)`.
        bwd_states : List[jnp.ndarray]
            Backward LSTM states for each layer, each of shape
            `(batch, seq_len, hidden_dim)`.

    forward_logits(char_ids, jax_rng, deterministic=True)
        Return the forward and backward language‑model logits together
        with the intermediate representations.

        Returns
        -------
        fwd_logits : jnp.ndarray
            Forward language‑model logits, shape
            `(batch, seq_len, word_vocab_size)`.
        bwd_logits : jnp.ndarray
            Backward language-model logits in the original left-to-right
            time order (position ``t`` scores token ``t-1``); same shape as
            ``fwd_logits``.
        char_embs : jnp.ndarray
            Raw character‑CNN embeddings (as in ``forward_backbone``).
        fwd_states : List[jnp.ndarray]
            Forward LSTM states.
        bwd_states : List[jnp.ndarray]
            Backward LSTM states.

    forward_embeddings(char_embs, fwd_states, bwd_states)
        Produce the final contextualised embedding vector.

        Returns
        -------
        x : jnp.ndarray
            Contextualised ELMo embedding of shape
            `(batch, seq_len, common_dim)`.  It is a weighted sum of the
            projected character‑CNN output and each concatenated
            forward/backward LSTM layer, scaled by `gamma`.

    """
    def __init__(self, char_vocab_size, char_dim, filters, highway_layers,
                 proj_dim, common_dim, hidden_dim, num_layers, word_vocab_size,
                 input_dropout=0.1, lstm_dropout=0.1, output_dropout=0.1, *,
                 rngs: nnx.Rngs):
        # Submodules
        self.char_cnn = CharCNN(char_vocab_size, char_dim, filters, highway_layers, proj_dim=proj_dim, rngs=rngs)
        self.bilstm = StackedBiLSTM(proj_dim, hidden_dim, num_layers, dropout=lstm_dropout, rngs=rngs)
        self.fwd_head = LMHead(hidden_dim, word_vocab_size, rngs=rngs)
        self.bwd_head = LMHead(hidden_dim, word_vocab_size, rngs=rngs)

        # Scalar mix for ELMo embeddings
        self.common_dim = common_dim
        self.scalar_weights = nnx.Param(jnp.zeros(num_layers + 1))
        self.gamma = nnx.Param(jnp.array(1.0))

        # Projection layers to common_dim
        self.layer_projections = nnx.List()

        # CharCNN output to common dim
        self.layer_projections.append(
            nnx.Linear(proj_dim, common_dim, rngs=rngs)
        )

        # BiLSTM layers (2 * hidden_dim) to common_dim
        for _ in range(num_layers):
            self.layer_projections.append(
                nnx.Linear(2 * hidden_dim, common_dim, rngs=rngs)
            )

        assert len(self.layer_projections) == len(self.scalar_weights.value)

        # Dropout layers
        self.input_dropout = nnx.Dropout(rate=input_dropout, rngs=rngs)
        self.output_dropout = nnx.Dropout(rate=output_dropout, rngs=rngs)

    def forward_backbone(self, char_ids, jax_rng, deterministic: bool = True):
        char_embs = self.char_cnn(char_ids)

        if not deterministic:
            assert jax_rng is not None
            rng_in, rng_lstm = jax.random.split(jax_rng)
            x = self.input_dropout(char_embs, rngs=rng_in)
        else:
            x = char_embs
            rng_lstm = jax_rng

        _, fwd_states, bwd_states = self.bilstm(x, deterministic=deterministic, jax_rng=rng_lstm)

        return char_embs, fwd_states, bwd_states

    def forward_logits(self, char_ids, jax_rng, deterministic: bool = True):
        char_embs, fwd_states, bwd_states = self.forward_backbone(
            char_ids, deterministic=deterministic, jax_rng=jax_rng
        )

        top_fwd = fwd_states[-1]
        top_bwd = bwd_states[-1]

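        # Pass `deterministic` through so output dropout is off at evaluation time
        # (otherwise nnx.Dropout falls back to its own default, deterministic=False)
        top_fwd = self.output_dropout(top_fwd, deterministic=deterministic)
        top_bwd = self.output_dropout(top_bwd, deterministic=deterministic)

        fwd_logits = self.fwd_head(top_fwd)
        # The head is applied position-wise and top_bwd is already in left-to-right
        # order, so no flipping is needed: bwd_logits[:, t] scores token t-1.
        bwd_logits = self.bwd_head(top_bwd)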

        return fwd_logits, bwd_logits, char_embs, fwd_states, bwd_states

    def forward_embeddings(self, char_embs, fwd_states, bwd_states):
        layers = [char_embs] + [
            jnp.concatenate([fwd, bwd], axis=-1)
            for fwd, bwd in zip(fwd_states, bwd_states)
        ]

        w = jax.nn.softmax(self.scalar_weights.value)

        projected = [
            proj(layer) for proj, layer in zip(self.layer_projections, layers)
        ]

        x = sum(w_i * p for w_i, p in zip(w, projected))
        x = self.gamma.value * x

        return x

Loss Function

We have already discussed the learning objective; below is an implementation of a masked cross-entropy loss function that excludes padding positions from the loss computation.

def masked_cross_entropy(logits, targets, pad_id=0):
    """
    Computes the average cross‑entropy loss for a batch while ignoring
    padding tokens.

    The function applies a *mask* to the per‑token loss so that any token
    whose target index equals ``pad_id`` is dropped from the loss
    calculation.

    Parameters
    ----------
    logits : jnp.ndarray
        Logits produced by the model.  Expected shape
        ``(batch, seq_len, vocab_size)``.
    targets : jnp.ndarray
        Ground‑truth token indices.  Expected shape ``(batch, seq_len)``
        with integer values in ``[0, vocab_size)``.  Positions that
        contain ``pad_id`` are treated as padding.
    pad_id : int, default=0
        The integer value used to mark padding positions in ``targets``.

    Returns
    -------
    loss : float
        The mean cross‑entropy loss over all non‑padding tokens in the
        batch.  The denominator is the total number of non‑padding
        tokens plus a small constant ``1e-12`` to avoid division by
        zero.

    """
    vocab_size = logits.shape[-1]
    log_probs = jax.nn.log_softmax(logits, axis=-1)
    targets_onehot = jax.nn.one_hot(targets, vocab_size)
    per_token_loss = -jnp.sum(targets_onehot * log_probs, axis=-1)
    mask = (targets != pad_id).astype(jnp.float32)
    return jnp.sum(per_token_loss * mask) / (jnp.sum(mask) + 1e-12)
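As a sanity check on the masking, padding a sequence should leave the loss unchanged. The self-contained sketch below reimplements the same loss as a hypothetical masked_ce_gather. It uses jnp.take_along_axis instead of a one-hot product, which avoids materialising a (batch, seq_len, vocab_size) one-hot tensor and can matter for large vocabularies. The logits and token ids are made up.

```python
import jax
import jax.numpy as jnp

def masked_ce_gather(logits, targets, pad_id=0):
    """Hypothetical variant of masked_cross_entropy that gathers target log-probs directly."""
    log_probs = jax.nn.log_softmax(logits, axis=-1)
    # Pick log p(target) at each position: [B, T, 1] -> [B, T]
    per_token_loss = -jnp.take_along_axis(log_probs, targets[..., None], axis=-1)[..., 0]
    mask = (targets != pad_id).astype(jnp.float32)
    return jnp.sum(per_token_loss * mask) / (jnp.sum(mask) + 1e-12)

# Made-up logits over a 6-word vocabulary; the last two positions are padding (pad_id=0)
logits = jax.random.normal(jax.random.PRNGKey(0), (1, 5, 6))
targets = jnp.array([[3, 1, 4, 0, 0]])

padded_loss = masked_ce_gather(logits, targets)
unpadded_loss = masked_ce_gather(logits[:, :3], targets[:, :3])
print(jnp.allclose(padded_loss, unpadded_loss))  # True: padded positions contribute nothing
```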

Define Training Loop

Next, we define the training loop, which encompasses the standard steps performed for each batch: executing the forward pass, computing gradients, and updating model parameters. This requires initializing the model and optimizer, configuring checkpointing to persist training state, and iterating over the training step for the desired number of epochs. These components follow conventional neural network training practice and do not introduce any ELMo-specific complexity.

There are, however, two practical considerations worth noting:

  1. We intentionally select hyperparameters corresponding to a reduced ELMo configuration, as the full-scale model is computationally infeasible on the available hardware.
  2. We limit pre-training to five epochs for the same reason.

Despite these constraints, even this abbreviated pre-training regimen yields clearly observable improvements on downstream tasks, as demonstrated in subsequent sections.

@jax.jit
def train_step(optimizer, batch, jax_rng):
    """
    Performs a single optimisation update on the ELMo‐style model.

    The function runs a forward pass of the model to obtain forward and
    backward language‑model logits, computes the masked cross‑entropy
    loss for each direction, sums the two losses, back‑propagates the
    gradients, and finally applies the optimiser's update rule.

    Parameters
    ----------
    optimizer : nnx.ModelAndOptimizer
        An ``nnx.ModelAndOptimizer`` that owns the model to be trained.
        It exposes a ``model`` attribute and an ``update`` method that
        accepts the gradients (an ``nnx.State``).
    batch : Mapping[str, jnp.ndarray]
        A batch of training data mapping the following keys to
        integer arrays:
        * ``"char_ids"``: shape ``(batch, seq_len, word_len)``
          containing character indices for each token.
        * ``"target_ids"``: shape ``(batch, seq_len)``
          containing the target token indices for the language‑model head.
    jax_rng : jax.random.PRNGKey
        Random number generator used for dropout and other stochastic
        components of the model.

    Returns
    -------
    optimizer : nnx.ModelAndOptimizer
        The optimiser after the update.  It holds the freshly
        updated model parameters.
    loss : float
        The scalar loss value that was optimised.  It is the sum of the
        forward and backward masked cross‑entropy losses.

    Notes
    -----
    * Keep the *returned* optimiser: it holds the updated parameters in
      ``optimizer.model``.  ``ELMoTrainer.train_epoch`` reassigns
      ``self.optimizer`` for exactly this reason.

    """
    char_ids = batch["char_ids"]
    targets  = batch["target_ids"]

    def loss_fn(model):
        fwd_logits, bwd_logits, _, _, _ = model.forward_logits(
            char_ids, deterministic=False, jax_rng=jax_rng
        )

        fwd_loss = masked_cross_entropy(
            fwd_logits[:, :-1, :], targets[:, 1:]
        )

        bwd_loss = masked_cross_entropy(
            bwd_logits[:, 1:, :], targets[:, :-1]
        )

        return fwd_loss + bwd_loss

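    # A single pass computes both the loss and its gradients. Calling loss_fn
    # separately would rerun the model and draw a different output-dropout mask.
    loss, grads = nnx.value_and_grad(loss_fn)(optimizer.model)
    optimizer.update(grads)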

    return optimizer, loss
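The slicing in loss_fn pairs each language-model direction with the right targets. At position t the forward LSTM has read tokens 0..t and is scored against token t+1. The backward LSTM at position t has read tokens t..T-1 and is scored against token t-1. The snippet below replays that slicing on position indices for a hypothetical five-token sentence, so you can read the pairing off directly.

```python
import numpy as np

# Hypothetical 5-token sentence; we slice position indices instead of real token ids
words = ["<bos>", "the", "ship", "sailed", "<eos>"]
pos = np.arange(len(words))[None, :]          # shape (1, T), like targets

# Same slicing as train_step
fwd_src, fwd_tgt = pos[:, :-1], pos[:, 1:]    # fwd_logits[:, :-1] vs targets[:, 1:]
bwd_src, bwd_tgt = pos[:, 1:], pos[:, :-1]    # bwd_logits[:, 1:]  vs targets[:, :-1]

for s, t in zip(fwd_src[0], fwd_tgt[0]):
    print(f"forward  state at {words[s]!r:9} -> predicts {words[t]!r}")
for s, t in zip(bwd_src[0], bwd_tgt[0]):
    print(f"backward state at {words[s]!r:9} -> predicts {words[t]!r}")
```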
# RNG setup
rngs = nnx.Rngs(0)

# Hyperparameters
char_vocab_size = len(char_to_id)
char_dim = 16
filters = [
    (1, 64),
    (2, 128),
    (3, 256),
    (4, 256),
    (5, 256),
    (6, 256),
]
highway_layers = 2
proj_dim = 512
hidden_dim = 512
num_layers = 2

word_vocab_size = len(vocab)
common_dim = 512

# Model instantiation
model = ElmoModel(
    char_vocab_size=char_vocab_size,
    char_dim=char_dim,
    filters=filters,
    highway_layers=highway_layers,
    proj_dim=proj_dim,
    common_dim=common_dim,
    hidden_dim=hidden_dim,
    num_layers=num_layers,
    word_vocab_size=word_vocab_size,
    input_dropout=0.1, 
    lstm_dropout=0.3, 
    output_dropout=0.1,
    rngs=rngs
)

# Optimizer instantiation
tx = optax.adamw(1e-3, weight_decay=1e-2)
optimizer = nnx.ModelAndOptimizer(model, tx)
class ELMoTrainer:
    """
    A training wrapper for an ELMo bidirectional language model.  
    It handles epoch‑wise training, validation
    loss computation, and an early‑stopping strategy.

    Parameters
    ----------
    optimizer : nnx.ModelAndOptimizer
        The optimiser that owns the model to be trained.  It must expose a
        ``model`` attribute (the trainable model) and an ``update`` method.
    patience : int, optional (default=3)
        Number of consecutive validation epochs without improvement on the
        loss after which training is stopped early and the best model
        parameters are restored.
    rng_seed : int, optional (default=0)
        Seed for the JAX random number generator.

    Attributes
    ----------
    optimizer : nnx.ModelAndOptimizer
        The optimiser being used.
    patience : int
        See ``patience`` above.
    best_loss : float
        Best validation loss observed so far.  Initialized to ``∞``.
    wait : int
        Number of epochs since the last improvement.
    best_params : nnx.State | None
        The model parameters that produced the best validation loss.
        Stored as an ``nnx.State`` so that they can be written back into
        ``optimizer.model`` with ``nnx.update`` on early stopping.
    jax_rng : jax.random.PRNGKey
        Current random key.

    Methods
    -------
    train_epoch(train_loader)
        Runs one epoch of training on ``train_loader``.
    validate(model, val_loader)
        Computes the average loss over ``val_loader``.
    validate_and_stop(val_loader)
        Performs validation, logs results and checks the early‑stopping
        criterion.  Returns ``True`` if training should stop.
    """
    def __init__(self, optimizer, patience=3, rng_seed=0):
        self.optimizer = optimizer
        self.patience = patience
        self.best_loss = float("inf")
        self.wait = 0
        self.best_params = None
        self.jax_rng = jax.random.PRNGKey(rng_seed)

    def train_epoch(self, train_loader):
        total_loss = 0.0
        n = 0
        for batch in train_loader:
            self.jax_rng, subkey = jax.random.split(self.jax_rng)
            self.optimizer, loss = train_step(self.optimizer, batch, jax_rng=subkey)
            total_loss += float(loss)
            n += 1
        return total_loss / max(1, n)

    def validate(self, model, val_loader):
        total_loss = 0.0
        n = 0

        for batch in val_loader:
            targets  = batch["target_ids"]

            fwd_logits, bwd_logits, _, _, _ = model.forward_logits(
                batch["char_ids"], jax_rng=self.jax_rng
            )

            # Forward predicts t+1
            fwd_loss = masked_cross_entropy(fwd_logits[:, :-1, :], targets[:, 1:])

            # Backward predicts t-1
            bwd_loss = masked_cross_entropy(bwd_logits[:, 1:, :], targets[:, :-1])

            loss = fwd_loss + bwd_loss

            total_loss += float(loss)
            n += 1

        mean_loss = total_loss / max(1, n)
        return mean_loss


    def validate_and_stop(self, val_loader):
        val_loss = self.validate(self.optimizer.model, val_loader)
        print(f"  val_loss={val_loss:.4f}")

        if val_loss < self.best_loss:
            self.best_loss = val_loss
            self.best_params = nnx.state(self.optimizer.model)
            self.wait = 0
            print("  New best model saved.")
            return False  # continue training

        self.wait += 1
        print(f"  No improvement ({self.wait}/{self.patience})")

        if self.wait >= self.patience:
            print("Early stopping triggered!")
            nnx.update(self.optimizer.model, self.best_params)  # restore best params
            return True  # stop training

        return False
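You can replay the early-stopping rule in validate_and_stop on a list of validation losses. simulate_early_stopping below is a hypothetical pure-Python helper that mirrors that logic; it is not part of the trainer. Its indices count validation rounds, not epochs, and the training loop validates only every second epoch.

```python
def simulate_early_stopping(val_losses, patience):
    """Hypothetical replay of validate_and_stop.

    Returns (round at which training stops, or None; round with the best loss).
    """
    best_loss, best_idx, wait = float("inf"), None, 0
    for i, loss in enumerate(val_losses):
        if loss < best_loss:
            best_loss, best_idx, wait = loss, i, 0   # new best: reset the counter
            continue
        wait += 1                                    # no improvement this round
        if wait >= patience:
            return i, best_idx                       # stop and restore round best_idx
    return None, best_idx

print(simulate_early_stopping([3.1, 2.7, 2.8, 2.9, 3.0], patience=3))    # (4, 1)
print(simulate_early_stopping([3.1, 2.7, 2.8, 2.9, 3.0], patience=100))  # (None, 1)
```

With patience=100, as used for pre-training, the rule never fires over five epochs. The trainer then simply tracks the best validation loss.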
# Where to save model checkpoints
ckpt_dir = "./checkpoints/elmo/state/"
checkpointer = ocp.StandardCheckpointer()
Caution

When creating your checkpoints, make sure you checkpoint the state of the model that holds the updated weights. In our case, that is the model owned by the optimizer inside our trainer (trainer.optimizer.model), not the model initialized above, whose weights are never updated.
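The underlying reason is that JAX transformations are functional. A jitted function returns new arrays instead of mutating its inputs, which is why train_step returns the optimizer and ELMoTrainer reassigns self.optimizer. Below is a minimal pure-JAX analogy, assuming nothing about Flax internals. A plain dict stands in for the model and a toy update stands in for train_step.

```python
import jax
import jax.numpy as jnp

params = {"w": jnp.ones(3)}   # stands in for the freshly initialised `model`

@jax.jit
def step(p):
    # Returns NEW arrays; the input pytree is never modified in place
    return jax.tree.map(lambda x: x - 0.5, p)

new_params = step(params)
print(params["w"])      # [1. 1. 1.]  <- the original reference still holds the initial values
print(new_params["w"])  # [0.5 0.5 0.5]
```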

trainer = ELMoTrainer(optimizer, patience=100) # No need to early stop with pre-training

train_ds = load_from_disk("c4_train")
val_ds = load_from_disk("c4_val")

epochs = 5
stop = False  # only reassigned on validation epochs, so it needs a default

for epoch in range(epochs):
    print(f"Epoch {epoch + 1}")

    # reset streaming ds
    train_loader = StreamingTextDataLoader(train_ds, vocab, char_to_id,
                                    seq_len=128, word_len=50,
                                    batch_size=20, shuffle_buffer=2048)

    train_loss = trainer.train_epoch(train_loader)
    print(f"  train_loss={train_loss:.4f}")

    if (epoch + 1) % 2 == 0:
        # reset streaming ds
        val_loader = StreamingTextDataLoader(val_ds, vocab, char_to_id,
                                        seq_len=128, word_len=50,
                                        batch_size=20, shuffle_buffer=2048)

        stop = trainer.validate_and_stop(val_loader)

    # Save checkpoint each epoch
    _, state = nnx.split(trainer.optimizer.model) # Make sure you use the model with the updated weights not the initialized model from above
    checkpointer.save(
        os.path.abspath(
            os.path.join(ckpt_dir, f"epoch{ epoch + 1 }")
        ), 
        state
    )

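    if stop:
        break

# Checkpoint saves are asynchronous; block until the last one is fully written
checkpointer.wait_until_finished()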

Evaluate Learned Embeddings on Downstream Task

Okay, now that we have a pre-trained ELMo model on hand, we are going to fine-tune it for a text classification task. For this task, we are going to use the Stanford Sentiment Treebank v2 (SST‑2) dataset.

Compare Random Weights to Pretrained Model

We are going to compare how an ELMo model with randomly initialized weights performs against our briefly pre-trained ELMo model. Below we initialize a random model in much the same way we initialized the model for pre-training, and we also load our pre-trained weights from the saved checkpoint.

# Initialize a random weights model

# RNG setup
rng = jax.random.PRNGKey(0)
rngs = nnx.Rngs(rng)

# Hyperparameters
char_vocab_size = len(char_to_id)
char_dim = 16
filters = [
    (1, 64),
    (2, 128),
    (3, 256),
    (4, 256),
    (5, 256),
    (6, 256),
]
highway_layers = 2
proj_dim = 512
hidden_dim = 512
num_layers = 2

word_vocab_size = len(vocab)
common_dim = 512

# Model instantiation
model_random = ElmoModel(
    char_vocab_size=char_vocab_size,
    char_dim=char_dim,
    filters=filters,
    highway_layers=highway_layers,
    proj_dim=proj_dim,
    common_dim=common_dim,
    hidden_dim=hidden_dim,
    num_layers=num_layers,
    word_vocab_size=word_vocab_size,
    input_dropout=0.1, 
    lstm_dropout=0.3, 
    output_dropout=0.1,
    rngs=rngs
)
# Load checkpointed model

# Construct an abstract version of the model (this is an empty scaffold so memory utilization is minimal)
abstract_model = nnx.eval_shape(
    lambda: ElmoModel(
        char_vocab_size=char_vocab_size,
        char_dim=char_dim,
        filters=filters,
        highway_layers=highway_layers,
        proj_dim=proj_dim,
        common_dim=common_dim,
        hidden_dim=hidden_dim,
        num_layers=num_layers,
        word_vocab_size=word_vocab_size,
        input_dropout=0.1,
        lstm_dropout=0.3,
        output_dropout=0.1,
        rngs=nnx.Rngs(0)
    )
)

# Split to get graphdef and an abstract state
graphdef, abstract_state = nnx.split(abstract_model)

# Restore into that abstract state
ckpt_dir = "./checkpoints/elmo/state/"
epoch = 5
checkpointer = ocp.StandardCheckpointer()
# Pass abstract_state as the target tree so Orbax can check the checkpoint
# against the expected structure (omitting it triggers an "UNSAFE" warning)
restored_state = checkpointer.restore(
    os.path.abspath(os.path.join(ckpt_dir, f"epoch{epoch}")),
    abstract_state
)

# Merge to produce a real model with pretrained weights
model_trained = nnx.merge(graphdef, restored_state)
Tip

If you pre-trained on a GPU and want to load the pre-trained weights onto a different device (for example, a CPU-only machine), you will need to map the state onto the new device. The code cell below shows how to do that.

# To load the model that was trained using GPU onto a CPU only device use this:

# Ensure abstract_state is placed on the current local devices
cpu_device = jax.devices('cpu')[0]
sharding = jax.sharding.SingleDeviceSharding(cpu_device)

# Construct an abstract version of the model
abstract_model = nnx.eval_shape(
    lambda: ElmoModel(
        char_vocab_size=char_vocab_size,
        char_dim=char_dim,
        filters=filters,
        highway_layers=highway_layers,
        proj_dim=proj_dim,
        common_dim=common_dim,
        hidden_dim=hidden_dim,
        num_layers=num_layers,
        word_vocab_size=word_vocab_size,
        input_dropout=0.1,
        lstm_dropout=0.3,
        output_dropout=0.1,
        rngs=nnx.Rngs(0)
    )
)

# Split to get graphdef and an abstract state
graphdef, abstract_state = nnx.split(abstract_model)

# Map the sharding onto your abstract state leaves
abstract_state = jax.tree.map(
    lambda x: jax.ShapeDtypeStruct(x.shape, x.dtype, sharding=sharding),
    abstract_state
)

# Initialize and restore
ckpt_dir = "./checkpoints/elmo/state/"
epoch = 5
checkpoint_path = os.path.abspath(os.path.join(ckpt_dir, f"epoch{epoch}"))

checkpointer = ocp.StandardCheckpointer()

# Pass abstract_state
restored_state = checkpointer.restore(
    checkpoint_path,
    abstract_state 
)

# Load into the model
model_trained = nnx.merge(graphdef, restored_state)

Classifier Architecture

Okey-dokes, let’s put together a classifier head to leverage our ELMo embeddings on our downstream text classification task. The sub-components for our network will be:

  1. The ELMo model as the backbone
  2. An attention pooling layer with explicit padding masking to collapse our sequences into a single representation
  3. A two-layer multilayer perceptron with dropout and ReLU activation yielding our logits

It is worth noting that the attention pooling layer here is not the same as attention in transformer models. This layer simply learns a scalar scoring function over token embeddings.
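To make the masking concrete, here is a minimal NumPy re-statement of what the `AttnPool` layer below computes. It assumes the mask convention the classifier uses: 1.0 for real tokens and 0.0 for padding. The scoring weights `w` and `b` are random stand-ins for the learned `nnx.Linear(dim, 1)` projection. The explicit max-subtraction mirrors what `jax.nn.softmax` does internally for numerical stability.

```python
import numpy as np

def attn_pool(x, w, b, mask):
    """Masked attention pooling with one learned scalar score per token.

    x:    (B, T, D) token embeddings
    w:    (D,) weights of the D -> 1 scoring projection
    b:    scalar bias of that projection
    mask: (B, T), 1.0 for real tokens and 0.0 for padding
    """
    scores = x @ w + b                                      # (B, T) one score per token
    scores = scores + (mask - 1.0) * 1e9                    # padded slots pushed to about -1e9
    scores = scores - scores.max(axis=1, keepdims=True)     # stable softmax
    weights = np.exp(scores)
    weights = weights / weights.sum(axis=1, keepdims=True)  # softmax over the sequence axis
    pooled = (x * weights[..., None]).sum(axis=1)           # (B, D) weighted sum
    return pooled, weights

rng = np.random.default_rng(0)
x = rng.normal(size=(2, 4, 3))          # B=2 sentences, T=4 slots, D=3 dims
mask = np.array([[1, 1, 1, 0],
                 [1, 1, 0, 0]], dtype=np.float64)
w, b = rng.normal(size=3), 0.0          # stand-ins for the learned parameters

pooled, weights = attn_pool(x, w, b, mask)
print(pooled.shape)        # (2, 3)
print(weights.round(3))    # padded slots get weight 0.0

# Perturbing a padded position does not change the pooled output.
x_changed = x.copy()
x_changed[1, 3] += 100.0
print(np.allclose(attn_pool(x_changed, w, b, mask)[0], pooled))  # True
```

The last check shows the property the padding mask is meant to guarantee: whatever sits in a padded slot has no influence on the pooled vector.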

class AttnPool(nnx.Module):
    """
    Attention‑based pooling layer that collapses a sequence of vectors into a
    single representation using a learnable attention weight.

    Parameters
    ----------
    dim : int
        Dimensionality of each input vector (``x.shape[-1]``).
    rngs : nnx.Rngs
        Random number generators used to initialize the scoring projection.

    Returns
    -------
    jnp.ndarray
        A tensor of shape ``(batch_size, dim)`` containing the weighted
        sum of the input sequence.

    """
    def __init__(self, dim, *, rngs):
        self.proj = nnx.Linear(dim, 1, rngs=rngs)

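    def __call__(self, x, mask=None):
        # x: [B, T, dim]; mask: [B, T] with 1.0 for real tokens and 0.0 for padding
        scores = self.proj(x).squeeze(-1)  # [B, T] one scalar score per token
        if mask is not None:
            # Push padded positions to about -1e9 so softmax gives them ~0 weight
            scores = scores + (mask - 1) * 1e9
        weights = jax.nn.softmax(scores, axis=1)  # normalize over the sequence axis
        return jnp.sum(x * weights[..., None], axis=1)  # [B, dim] weighted sum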


class ElmoClassifier(nnx.Module):
    """
    Sequence classifier that builds on a pre‑trained ELMo backbone.

    The network follows the classic ELMo‑to‑text‑classification pipeline:

    1. **ELMo backbone** – A shared ELMo `ElmoModel` is used to generate
       contextual embeddings for each token (`char_ids`).
    2. **Mask‑aware attention pooling** – The token‑wise embeddings are
       weighted with an attention mechanism that respects the padding mask.
    3. **MLP classifier** – A two‑layer MLP with dropout and ReLU non‑linearity
       produces the final logits for *n_classes*.

    Parameters
    ----------
    elmo_model : ElmoModel
        Pre‑trained ELMo backbone that exposes two forward
        stages:
          * ``forward_backbone(char_ids, deterministic, jax_rng)``
            returns word‑level embeddings and forward/backward LSTM states.
          * ``forward_embeddings(char_embs, fwd_states, bwd_states)``
            combines those outputs into the ELMo embeddings consumed by
            the classifier.
        The backbone must expose ``common_dim`` - the dimensionality of the ELMo
        embeddings that the classifier consumes.
    num_classes : int
        Number of target classes for the downstream classification task.
    dropout_rate : float, default 0.1
        Dropout probability applied before and after the hidden MLP layer.
    rngs : nnx.Rngs
        Random number generators

    Attributes
    ----------
    backbone : ElmoModel
        Reference to the ELMo backbone used for feature extraction.
    dropout : nnx.Dropout
        Dropout layer applied to the pooled representation and to the hidden
        MLP output.
    attn_pool : AttnPool
        Attention pooling head that weights tokens based on the ELMo output.
    classifier_hidden : nnx.Linear
        First linear layer of the classification MLP.
    classifier : nnx.Linear
        Final linear layer producing logits of shape ``(B, num_classes)``.

    Forward Pass
    ------------
    The module expects a 3‑D integer array of character IDs
    ``char_ids`` with shape ``(B, T, C)`` where:

    * **B** - batch size
    * **T** - sequence length (token count for each example)
    * **C** - maximum number of character IDs per token (``word_len``)

    **Deterministic flag** – When ``deterministic=True`` the dropout
    layers are disabled.

    **Masking** – A binary mask is inferred by treating tokens whose
    character IDs are all zero as padding (this assumes the character
    padding ID is 0).  Padded positions receive a large negative offset
    on their raw attention scores before softmax, so they get negligible
    weight.

    Returns
    -------
    logits : jnp.ndarray
        Unnormalised class scores with shape ``(B, num_classes)``.


    The returned ``logits`` can be fed to a standard cross‑entropy loss
    during training.

    Notes
    -----
    * The attention pooling performs **softmax over the sequence dimension**
      and uses a large negative constant to mask out padding before softmax
      (effectively treating those positions as having negligible weight),
      which is a stable and differentiable alternative to masking in the
      exponent step.
    * The model relies on the ELMo backbone providing *common_dim*‑dimensional
      embeddings.  If the backbone uses a different dimensionality, the
      attributes and the MLP width need adjustment accordingly.
    """
    def __init__(self, elmo_model: ElmoModel, num_classes: int, dropout_rate: float = 0.1, *, rngs: nnx.Rngs):
        self.backbone = elmo_model
        self.dropout = nnx.Dropout(rate=dropout_rate, rngs=rngs)
        self.attn_pool = AttnPool(self.backbone.common_dim, rngs=rngs)
        self.classifier_hidden = nnx.Linear(self.backbone.common_dim, self.backbone.common_dim, rngs=rngs)
        self.classifier = nnx.Linear(self.backbone.common_dim, num_classes, rngs=rngs)

    def __call__(self, char_ids, deterministic: bool = False, jax_rng=None):
        char_embs, fwd_states, bwd_states = self.backbone.forward_backbone(char_ids, deterministic=deterministic, jax_rng=jax_rng)
        elmo_embs = self.backbone.forward_embeddings(char_embs, fwd_states, bwd_states)

        mask = jnp.logical_not(jnp.all(char_ids == 0, axis=-1)).astype(jnp.float32)  # [B, T]
        pooled = self.attn_pool(elmo_embs, mask)

        x = self.dropout(pooled, deterministic=deterministic)
        x = jax.nn.relu(self.classifier_hidden(x))
        x = self.dropout(x, deterministic=deterministic)
        logits = self.classifier(x)
        return logits
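The padding rule in `ElmoClassifier.__call__` only works if the character padding ID is 0. Below is a small NumPy check of that rule on a hypothetical batch. We assume here that `char_to_id["<pad>"] == 0`.

```python
import numpy as np

# Hypothetical batch: B=2 sentences, T=4 token slots, C=5 characters per slot.
# Assumption: character ID 0 is the padding ID, i.e. char_to_id["<pad>"] == 0.
char_ids = np.zeros((2, 4, 5), dtype=np.int32)
char_ids[0, :3, :3] = [7, 8, 9]   # sentence 0: three real tokens
char_ids[1, :2, :2] = [4, 5]      # sentence 1: two real tokens

# Same rule as ElmoClassifier.__call__: a slot is padding iff all its character IDs are 0.
mask = np.logical_not(np.all(char_ids == 0, axis=-1)).astype(np.float32)
print(mask)
# [[1. 1. 1. 0.]
#  [1. 1. 0. 0.]]

# If "<pad>" were mapped to another ID (say 1), no slot would ever be all zeros,
# so every padded slot would be counted as a real token.
shifted = np.where(char_ids == 0, 1, char_ids)
shifted_mask = np.logical_not(np.all(shifted == 0, axis=-1))
print(shifted_mask.all())  # True
```

If a vocabulary assigns a different padding ID, comparing against `char_to_id["<pad>"]` instead of the literal 0 would keep the mask correct.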

Prepare Dataset

We need to build a new data loader. The current one was designed for language‑modeling pre‑training, but our goal is now text classification. Therefore, we must produce character and word indices that operate on whole sentences, not on a sliding window.
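The difference between the two batching schemes can be sketched in plain Python. This is a simplified stand-in, not the pre-training loader itself. We assume non-overlapping fixed-size windows over a concatenated token stream, and the two labeled reviews are hypothetical.

```python
# Two hypothetical labeled reviews (1 = positive, 0 = negative).
examples = [("a gripping and funny film", 1), ("dull and overlong", 0)]

# Language-modeling style (simplified): concatenate the text, then cut
# non-overlapping windows of a fixed number of tokens.
stream = [tok for text, _ in examples for tok in text.split()]
window = 4
lm_windows = [stream[i:i + window] for i in range(0, len(stream), window)]
print(lm_windows)
# [['a', 'gripping', 'and', 'funny'], ['film', 'dull', 'and', 'overlong']]
# The second window mixes both reviews, so no single label fits it.

# Classification style: one row per sentence, aligned with its label.
rows = [text.split() for text, _ in examples]
labels = [label for _, label in examples]
print(rows, labels)
```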

hf_ds = load_dataset("glue", "sst2")
train_stream = hf_ds["train"]
val_stream = hf_ds["validation"]

# DataLoader setup
batch_size = 256
seq_len = 64
word_len = 50

def encode_batch(texts, vocab, char_to_id, seq_len, word_len):
    """
    Encode a batch of raw text strings into fixed‑size integer tensors.

    Words and characters that are unseen in the supplied dictionaries are
    replaced with the special unknown token.  Sequences longer than the
    requested limits are truncated, while shorter ones are padded with the
    special padding token.

    Parameters
    ----------
    texts : Iterable[str]
        A batch of raw text strings.  Each
        element is split on whitespace to produce a list of words.
    vocab : Mapping[str, int]
        Word‑to‑ID dictionary.  Must contain the special tokens ``"<pad>"``
        and ``"<unk>"``.
    char_to_id : Mapping[str, int]
        Character‑to‑ID dictionary.  It must contain ``"<pad>"`` and
        ``"<unk>"`` for padding and unknown characters respectively.
    seq_len : int
        Maximum number of words per sentence that will be encoded.  All
        sentences are truncated to this length or padded with the word
        padding token.
    word_len : int
        Maximum number of characters per word that will be encoded.  Words
        longer than this length are truncated; shorter words are padded
        with the character padding token.

    Returns
    -------
    dict
        A dictionary with two keys:
        ``"word_ids"`` : numpy.ndarray of shape ``(batch_size, seq_len)``
        ``"char_ids"`` : numpy.ndarray of shape ``(batch_size, seq_len,
        word_len)``

        * ``word_ids[i, j]`` holds the ID of the *j*-th word in the
          *i*-th input string;  ``vocab[PAD_WORD]`` if the position is
          padded.
        * ``char_ids[i, j, k]`` holds the ID of the *k*-th character of
          the *j*-th word in the *i*-th input string;  ``char_to_id[PAD_CHAR]``
          if the position is padded.

    """
    PAD_WORD = "<pad>"
    UNK_WORD = "<unk>"
    PAD_CHAR = "<pad>"
    UNK_CHAR = "<unk>"

    batch_size = len(texts)
    word_ids = np.full((batch_size, seq_len), vocab.get(PAD_WORD), dtype=np.int32)
    char_ids = np.full((batch_size, seq_len, word_len), char_to_id[PAD_CHAR], dtype=np.int32)

    for i, text in enumerate(texts):
        toks = text.split()[:seq_len]
        # Encode word IDs
        wid = [vocab.get(w, vocab.get(UNK_WORD)) for w in toks]
        word_ids[i, :len(wid)] = wid
        # Encode char IDs
        for j, w in enumerate(toks):
            cids = [char_to_id.get(c, char_to_id[UNK_CHAR]) for c in w[:word_len]]
            char_ids[i, j, :len(cids)] = cids

    return {"word_ids": word_ids, "char_ids": char_ids}


def sst2_loader(train_ds, vocab, char_to_id, seq_len, word_len, batch_size):
    """
    Yield batched, encoded SST‑2 dataset.

    The function takes a HuggingFace ``datasets.Dataset`` containing the
    Stanford Sentiment Treebank v2 (SST‑2) data, shuffles it, encodes the
    textual component into word‑ and character‑ids, and yields a dictionary
    of JAX arrays for each mini‑batch.

    Parameters
    ----------
    train_ds : :class:`datasets.Dataset`
        A HuggingFace ``Dataset`` object that must contain at least two
        columns:
        ``"sentence"`` – raw text data (a list of strings)
        ``"label"``   – integer labels (0: negative, 1: positive)
    vocab : Mapping[str, int]
        Word‑to‑ID vocabulary.  Must contain the special tokens
        ``"<pad>"`` and ``"<unk>"`` used by :func:`encode_batch`.
    char_to_id : Mapping[str, int]
        Character‑to‑ID mapping.  Must contain the special tokens
        ``"<pad>"`` and ``"<unk>"``.
    seq_len : int
        Maximum number of words per sentence.  Sentences longer than this
        limit will be truncated; shorter ones padded to ``seq_len``.
    word_len : int
        Maximum number of characters per word.  Words longer than
        this limit are truncated; shorter ones padded to ``word_len``.
    batch_size : int
        Number of examples per yielded batch.

    Yields
    ------
    dict
        A dictionary with the following JAX array entries (dtype
        ``jnp.int32``); the final batch may hold fewer than
        ``batch_size`` examples:
        ``"char_ids"`` : shape ``(batch_size, seq_len, word_len)``
        ``"word_ids"`` : shape ``(batch_size, seq_len)``
        ``"labels"``  : shape ``(batch_size,)``

    """
    ds = train_ds.shuffle()
    for i in range(0, len(ds), batch_size):
        batch = ds[i:i + batch_size]

        # Each field is a list
        texts = batch["sentence"]
        labels = batch["label"]

        # Encode text to char IDs
        enc = encode_batch(texts, vocab, char_to_id, seq_len, word_len) # These are already batch sized

        yield {
            "char_ids": jnp.array(enc["char_ids"]),
            "word_ids": jnp.array(enc["word_ids"]),
            "labels": jnp.array(labels, dtype=jnp.int32),
        }

train_loader = sst2_loader(
    train_stream,
    vocab=vocab,
    char_to_id=char_to_id,
    seq_len=seq_len,
    word_len=word_len,
    batch_size=batch_size,
)

Phased Fine-Tuning

We fine‑tune the model in two distinct stages so that large gradients from the randomly initialized classifier head do not destroy the pre‑trained representations.

Phase 1, a “high‑level” adaptation:

  • The pre‑trained encoder weights, including the character CNN, highway layers, and the bi‑LSTM, are frozen.

  • The only trainable parameters are the layer projections and scalar‑mix parameters (weights and gamma), which blend the encoder layers, plus the attention pooling layer and the final classifier layer.

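Phase 2, a “deep” adaptation:

  • The bi‑LSTM parameters are now unfrozen so that the encoder can adjust its internal representations to the target task.

  • The layer projections, scalar‑mix parameters, attention pooling layer, and classifier remain trainable; parameters outside these groups, such as the character CNN and highway layers, stay frozen.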

  • A very small learning rate is employed to limit catastrophic forgetting.

The following code block explicitly lists the parameter groups updated in each phase.

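# Fine-tuning phase 1 updates the parameters of the classifier and the scalar mix parameters in the backbone
# Caveat: nnx.PathContains checks for a whole path element, so "classifier" does not match
# the "classifier_hidden" layer; add nnx.PathContains("classifier_hidden") to both filters
# if that layer should be trained as well.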
trainable_phase1 = nnx.All(
    nnx.Param,
    nnx.Any(
        nnx.PathContains("layer_projections"),
        nnx.PathContains("scalar_weights"),
        nnx.PathContains("gamma"),
        nnx.PathContains("attn_pool"),
        nnx.PathContains("classifier"),
    )
)

# Fine-tuning phase 2 additionally unfreezes the bi-LSTM layers in the backbone
trainable_phase2 = nnx.All(
    nnx.Param,
    nnx.Any(
        nnx.PathContains("bilstm"),
        nnx.PathContains("layer_projections"),
        nnx.PathContains("scalar_weights"),
        nnx.PathContains("gamma"),
        nnx.PathContains("attn_pool"),
        nnx.PathContains("classifier"),
    )
)
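To see which parameters each filter selects, here is a pure-Python mimic of the filters above. It assumes that `nnx.PathContains(key)` matches when `key` equals one whole element of a parameter's path, not a substring of an element. This is our reading of Flax's filter semantics. The parameter paths are hypothetical; the real ones depend on how `ElmoModel` names its submodules.

```python
# Hypothetical parameter paths (tuples of attribute names), loosely modeled on ElmoClassifier.
# The nnx.Param part of nnx.All is omitted: every path here is assumed to be a Param.
param_paths = [
    ("backbone", "char_cnn", "kernel"),
    ("backbone", "bilstm", "0", "kernel"),
    ("backbone", "layer_projections", "0", "kernel"),
    ("backbone", "scalar_weights"),
    ("backbone", "gamma"),
    ("attn_pool", "proj", "kernel"),
    ("classifier_hidden", "kernel"),
    ("classifier", "kernel"),
]

def path_contains(key):
    # Mimics PathContains: True if `key` is one whole element of the path.
    return lambda path: key in path

def any_of(*filters):
    # Mimics nnx.Any: True if at least one filter matches.
    return lambda path: any(f(path) for f in filters)

head_keys = ["layer_projections", "scalar_weights", "gamma", "attn_pool", "classifier"]
phase1 = any_of(*(path_contains(k) for k in head_keys))
phase2 = any_of(path_contains("bilstm"), phase1)

for path in param_paths:
    print("/".join(path), "phase1:", phase1(path), "phase2:", phase2(path))
```

Under that assumption, `classifier_hidden` is matched by neither filter because `"classifier"` is not an element of its path. Listing `"classifier_hidden"` explicitly would make that layer trainable.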

Define Training Loop

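Next, we set up the training step, initialize the classifier and optimizer, and run the training loop. To keep gradient updates confined to the intended parameters, we pass the phase's parameter filter (trainable_phase) to two places. It goes to nnx.DiffState, which restricts which parameters are differentiated, and to the wrt argument of nnx.ModelAndOptimizer, which restricts which parameters the optimizer updates.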

In the first phase we instantiate two copies of the model:

  1. one with a randomly initialized ELMo backbone, and
  2. one with the pre‑trained ELMo weights.

This dual run lets us directly compare performance and confirm that our pre‑training regime is effective. We run phase 1 for 10 epochs.

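In the output below, the first line of each epoch reports the randomly initialized backbone and the second line reports the pre‑trained backbone. The gap is large. After 10 epochs the pre‑trained backbone reaches a training accuracy of 0.786 (loss 0.451), versus 0.689 (loss 0.586) for the random one. This suggests that our pre‑training regime was indeed effective.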

@nnx.jit(static_argnames=("trainable_phase",))
def train_step(model_opt, batch, rng, *, trainable_phase):
    """
    Perform one training step for a JAX/NNX model using a custom optimizer.

    Parameters
    ----------
    model_opt: nnx.ModelAndOptimizer
        Wrapper that holds the classifier (``model_opt.model``) together
        with its optimizer state.

    batch: dict[str, jnp.ndarray]
        A batch dictionary produced by :func:`sst2_loader`.

    rng: jax.random.PRNGKey
        Random key forwarded to the ELMo backbone (``jax_rng``).

    trainable_phase: nnx filter
        Parameter filter (e.g. ``trainable_phase1``) that tells
        :class:`nnx.DiffState` which parts of the model should get
        gradients.  It must match the ``wrt`` filter used to build
        ``model_opt``.

    Returns
    -------
    tuple
        ``(model_opt, loss, acc, grads)``

        * ``model_opt`` - the :class:`nnx.ModelAndOptimizer` after applying
          the gradient update.
        * ``loss`` (float)  - the mean softmax cross‑entropy
          computed on the current batch.
        * ``acc`` (float)   - accuracy of the model on this batch.
        * ``grads``         - a PyTree of gradients for the parameters
          selected by ``trainable_phase``.

    """

    # DiffState must match optimizer wrt argument
    diff_state = nnx.DiffState(0, trainable_phase)

    def loss_fn(model):
        logits = model(batch["char_ids"], deterministic=False, jax_rng=rng)
        loss = optax.softmax_cross_entropy_with_integer_labels(
            logits, batch["labels"]
        ).mean()
        return loss, logits

    (loss, logits), grads = nnx.value_and_grad(
        loss_fn,
        has_aux=True,
        argnums=diff_state,
    )(model_opt.model)

    model_opt.update(grads)

    preds = jnp.argmax(logits, axis=-1)
    acc = jnp.mean(preds == batch["labels"])

    return model_opt, loss, acc, grads


classifier_random = ElmoClassifier(
    model_random,
    num_classes=2,
    dropout_rate=0.3,
    rngs=nnx.Rngs(jax.random.PRNGKey(1))
)

classifier_trained = ElmoClassifier(
    model_trained,
    num_classes=2,
    dropout_rate=0.3,
    rngs=nnx.Rngs(jax.random.PRNGKey(1))
)
tx = optax.adamw(1e-3, weight_decay=1e-2)

model_opt_random = nnx.ModelAndOptimizer(
    classifier_random,
    tx,
    wrt=trainable_phase1
)

model_opt_trained = nnx.ModelAndOptimizer(
    classifier_trained,
    tx,
    wrt=trainable_phase1
)
num_epochs = 10
train_rng = jax.random.PRNGKey(0)

for epoch in range(num_epochs):
    print(f"\n===== Epoch {epoch + 1}/{num_epochs} =====")

    # reinitialize or reshuffle dataset each epoch
    train_loader = sst2_loader(
        train_stream, vocab, char_to_id, seq_len, word_len, batch_size
    )

    epoch_loss_random = []
    epoch_loss_trained = []
    epoch_acc_random = []
    epoch_acc_trained = []

    for batch in train_loader:
        
        train_rng, subkey = jax.random.split(train_rng)
        model_opt_random, loss_random, acc_random, grads_random = train_step(model_opt_random, batch, rng=subkey, trainable_phase=trainable_phase1)
        model_opt_trained, loss_trained, acc_trained, grads_trained = train_step(model_opt_trained, batch, rng=subkey, trainable_phase=trainable_phase1)
        
        epoch_loss_random.append(float(loss_random))
        epoch_acc_random.append(float(acc_random))
        epoch_loss_trained.append(float(loss_trained))
        epoch_acc_trained.append(float(acc_trained))
        

    print(f"Epoch {epoch + 1} | loss={np.mean(epoch_loss_random):.4f} | acc={np.mean(epoch_acc_random):.4f}")
    print(f"Epoch {epoch + 1} | loss={np.mean(epoch_loss_trained):.4f} | acc={np.mean(epoch_acc_trained):.4f}")

===== Epoch 1/10 =====
Epoch 1 | loss=0.6870 | acc=0.5525
Epoch 1 | loss=0.6398 | acc=0.6273

===== Epoch 2/10 =====
Epoch 2 | loss=0.6799 | acc=0.5675
Epoch 2 | loss=0.5723 | acc=0.6964

===== Epoch 3/10 =====
Epoch 3 | loss=0.6665 | acc=0.5960
Epoch 3 | loss=0.5337 | acc=0.7277

===== Epoch 4/10 =====
Epoch 4 | loss=0.6487 | acc=0.6239
Epoch 4 | loss=0.5135 | acc=0.7440

===== Epoch 5/10 =====
Epoch 5 | loss=0.6328 | acc=0.6446
Epoch 5 | loss=0.4989 | acc=0.7517

===== Epoch 6/10 =====
Epoch 6 | loss=0.6181 | acc=0.6612
Epoch 6 | loss=0.4839 | acc=0.7643

===== Epoch 7/10 =====
Epoch 7 | loss=0.6073 | acc=0.6708
Epoch 7 | loss=0.4750 | acc=0.7704

===== Epoch 8/10 =====
Epoch 8 | loss=0.5989 | acc=0.6778
Epoch 8 | loss=0.4663 | acc=0.7777

===== Epoch 9/10 =====
Epoch 9 | loss=0.5911 | acc=0.6844
Epoch 9 | loss=0.4594 | acc=0.7807

===== Epoch 10/10 =====
Epoch 10 | loss=0.5864 | acc=0.6887
Epoch 10 | loss=0.4510 | acc=0.7863

Next, we extract the fine-tuned model from the optimizer and instantiate a new model_opt via nnx.ModelAndOptimizer(). Note that for phase two we reduce the learning rate to \(5 \times 10^{-5}\). This lower rate helps preserve the pretrained bidirectional LSTM weights while allowing for small, corrective updates. We then train for five additional epochs in this phase.
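Why does the learning rate, rather than the size of the gradient, control how far the weights move? Adam-style optimizers normalize each coordinate's update by a running estimate of the gradient's magnitude, so each step is roughly the learning rate in size. The NumPy sketch below works through the very first AdamW step for a single hypothetical parameter. It uses bias-corrected moments and decoupled weight decay scaled by the learning rate, which is how we understand `optax.adamw` to apply it; `b1`, `b2`, and `eps` are set to optax's defaults.

```python
import numpy as np

def adamw_first_step(p, g, lr, wd=1e-2, b1=0.9, b2=0.999, eps=1e-8):
    """Magnitude of the first AdamW update for parameter value p with gradient g."""
    m_hat = ((1 - b1) * g) / (1 - b1)          # bias-corrected first moment (= g)
    v_hat = ((1 - b2) * g**2) / (1 - b2)       # bias-corrected second moment (= g**2)
    update = m_hat / (np.sqrt(v_hat) + eps) + wd * p
    return lr * np.abs(update)

p = 0.5  # hypothetical pre-trained weight
for lr in (1e-3, 5e-5):
    small_g = adamw_first_step(p, g=0.01, lr=lr)
    large_g = adamw_first_step(p, g=100.0, lr=lr)
    print(f"lr={lr:g}: step with g=0.01 -> {small_g:.3e}, with g=100 -> {large_g:.3e}")
```

The step is essentially the same for a tiny and a huge gradient, and it shrinks 20-fold when the learning rate drops from \(10^{-3}\) to \(5 \times 10^{-5}\). Later steps are not exactly the learning rate in size, since the moment estimates blend several gradients, but the learning rate still sets their rough scale.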

tx = optax.adamw(5e-5, weight_decay=1e-2)

model_opt = nnx.ModelAndOptimizer(
    model_opt_trained.model,
    tx,
    wrt=trainable_phase2
)
num_epochs = 5
train_rng = jax.random.PRNGKey(0)

for epoch in range(num_epochs):
    print(f"\n===== Epoch {epoch + 1}/{num_epochs} =====")

    # reinitialize dataset each epoch
    train_loader = sst2_loader(
        train_stream, vocab, char_to_id, seq_len, word_len, batch_size
    )

    epoch_loss = []
    epoch_acc = []

    for batch in train_loader:
        train_rng, subkey = jax.random.split(train_rng)
        model_opt, loss, acc, grads = train_step(model_opt, batch, rng=subkey, trainable_phase=trainable_phase2)
        epoch_loss.append(float(loss))
        epoch_acc.append(float(acc))

    print(f"Epoch {epoch + 1} | loss={np.mean(epoch_loss):.4f} | acc={np.mean(epoch_acc):.4f}")

===== Epoch 1/5 =====
Epoch 1 | loss=0.4292 | acc=0.8011

===== Epoch 2/5 =====
Epoch 2 | loss=0.4156 | acc=0.8091

===== Epoch 3/5 =====
Epoch 3 | loss=0.4039 | acc=0.8141

===== Epoch 4/5 =====
Epoch 4 | loss=0.3953 | acc=0.8215

===== Epoch 5/5 =====
Epoch 5 | loss=0.3860 | acc=0.8254

Finally, let's check the performance on the validation data.

val_loader = sst2_loader(
    val_stream,
    vocab=vocab,
    char_to_id=char_to_id,
    seq_len=seq_len,
    word_len=word_len,
    batch_size=batch_size,
)

acc = []
for batch in val_loader:
    raw_preds = model_opt.model(batch["char_ids"], deterministic=True)
    preds = np.argmax(raw_preds, axis=1)
    batch_acc = np.sum(batch["labels"] == preds) / preds.shape[0]
    acc.append(batch_acc)

print(f"Accuracy on validation: {np.mean(acc): .3f}")
Accuracy on validation:  0.742
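One caveat on this number is that the loop averages per-batch accuracies, which gives the smaller final batch the same weight as a full one. SST-2's validation split has 872 sentences, so with `batch_size = 256` the last batch holds only 104. The sketch below contrasts that average with the sample-weighted accuracy (total correct divided by total examples). The per-batch correct counts are hypothetical.

```python
import numpy as np

def mean_of_batch_accuracies(correct, sizes):
    # What the loop above computes: every batch counts equally.
    return float(np.mean(np.asarray(correct) / np.asarray(sizes)))

def sample_weighted_accuracy(correct, sizes):
    # Every example counts equally: total correct / total examples.
    return float(np.sum(correct) / np.sum(sizes))

sizes = [256, 256, 256, 104]    # 872 validation examples in batches of 256
correct = [200, 190, 195, 60]   # hypothetical correct-prediction counts
print(mean_of_batch_accuracies(correct, sizes))
print(sample_weighted_accuracy(correct, sizes))
```

In the validation loop, this amounts to accumulating `int(np.sum(batch["labels"] == preds))` and `preds.shape[0]` across batches and dividing at the end. We have not re-run the evaluation this way, so the 0.742 above is the per-batch average.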

Conclusion

This was a long one. Good job making it all the way to the end!

In this post we reviewed the differences between static and contextual embeddings, delved into the math and architecture of recurrent neural networks, and implemented ELMo practically from scratch in JAX. On top of that, we fine-tuned ELMo to perform text classification and showed the difference between starting from pre-trained embeddings and starting from random weights.

Coming Next

In the next post of this series, we will go over the transformer architecture, take a look at how attention works and implement a transformer model from scratch.

I hope to see you there!

References

Peters, Matthew E., Mark Neumann, Mohit Iyyer, Matt Gardner, Christopher Clark, Kenton Lee, and Luke Zettlemoyer. 2018. “Deep Contextualized Word Representations.” In Proceedings of the 2018 Conference of the North American Chapter of the Association for Computational Linguistics: Human Language Technologies, Volume 1 (Long Papers), 2227–37. New Orleans, Louisiana: Association for Computational Linguistics. https://aclanthology.org/N18-1202.pdf.