flowchart BT
X[$$x_t$$] -->|input| H[$$h_t$$]
H -->|output| Y[$$y_t$$]
H -->|recurrent state| H
style H fill:#eef,stroke:#333,stroke-width:1.5px
Overview
We go over contextual embeddings and recurrent neural networks, describing their inner workings and walking through an implementation of Embeddings from Language Models (ELMo) from scratch using the JAX ecosystem.
Introduction
In previous posts we learned what embeddings are, why they are important, and how they can be leveraged in practice. We got our hands dirty by implementing two of the earliest foundational methods: GloVe and Skip-gram with negative sampling (SGNS).
If you need a refresher or simply haven’t seen them, you can check out my posts on GloVe and SGNS.
In this post we will briefly review the differences between static and contextual embeddings and discuss some of the methodologies used to create contextual embeddings, leading us to the architecture of recurrent neural networks and how they are leveraged to build contextual embeddings.
Finally, we will delve into the architecture of Embeddings from Language Models (ELMo), and implement it from scratch.
From Static to Contextual Embeddings
As a reminder, a strength of contextual embeddings is the ability to discriminate between polysemous words. For example, the word “ship” can refer to the verb, as in shipping an item, or to the noun, a water vessel. A static embedding would represent both meanings of “ship” with the same vector. In contrast, contextual embeddings are able to disambiguate the two meanings of the same word.
Since our vectors now need to be aware of context, the sequence in which words occur must influence the resultant vectors. Importantly, not only during training must our vectors be aware of context but also during inference. Therefore, we must build vector embeddings dynamically (on the fly) during inference to ensure that this is the case. This contrasts with earlier methods like SGNS or GloVe:
- SGNS trains a vector for each word by predicting its local context, but at inference it simply looks up the pre‑computed static vector—there’s no per‑sentence adaptation.
- GloVe builds a word‑co‑occurrence matrix and factorizes it, again yielding one static vector per word that is reused for every sentence.
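To make the contrast concrete, here is a minimal sketch of static-embedding inference (the table holds made-up random vectors, not trained embeddings): a pure lookup that returns the same vector for “ship” regardless of the sentence.

```python
import numpy as np

# Toy static embedding table with made-up random vectors (not trained).
rng = np.random.default_rng(0)
vocab = ["we", "ship", "items", "the", "sank"]
static_table = {w: rng.normal(size=4) for w in vocab}

def static_embed(sentence):
    # Inference is a pure lookup: the surrounding words play no role.
    return [static_table[w] for w in sentence.split()]

v_verb = static_embed("we ship items")[1]   # "ship" as a verb
v_noun = static_embed("the ship sank")[1]   # "ship" as a noun
assert np.allclose(v_verb, v_noun)          # same vector for both senses
```

A contextual model, by contrast, must recompute the vector for “ship” from each sentence at inference time.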
One way we include information from sequences is by using recurrent computation. This brings us to the topic of recurrent neural networks.
Recurrent Neural Networks
A recurrent neural network (RNN) is a neural architecture designed for sequential data, in which information from previous time steps is carried forward through a recurrent hidden state. At each time step, the hidden state is updated based on the current input and the previous hidden state, and the output is computed from this hidden representation.
We have two equations that define a simple RNN, also known as an Elman RNN.
\[ h_{t} = \phi(W_{h}h_{t-1} + W_{x}x_{t} + b) \]
\[ y_{t} = \psi(W_{y}h_{t} + b_{y}) \]
Where \(x_{t}\), \(h_{t}\), and \(y_{t}\) are the input, hidden state, and output at time \(t\), respectively. The weight matrices \(W\) are linear operators that project variables from one space into another. For example, \(W_{x}\) projects \(x_{t}\) into the hidden state space. Notice that these weight matrices are not indexed by \(t\), signifying that they are shared across time, which allows RNNs to generalize to varying sequence lengths.
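To see the weight sharing in action, here is a minimal sketch of an Elman RNN unrolled with `jax.lax.scan` (the dimensions and random weights are illustrative assumptions; \(\phi = \tanh\) and \(\psi\) is taken to be the identity):

```python
import jax
import jax.numpy as jnp

# Toy dimensions: input, hidden, output sizes and sequence length (assumptions).
d_in, d_h, d_out, T = 3, 5, 2, 7
key = jax.random.PRNGKey(0)
k1, k2, k3, kx = jax.random.split(key, 4)
Wh = jax.random.normal(k1, (d_h, d_h)) * 0.1
Wx = jax.random.normal(k2, (d_h, d_in)) * 0.1
Wy = jax.random.normal(k3, (d_out, d_h)) * 0.1
b, b_y = jnp.zeros(d_h), jnp.zeros(d_out)

def step(h_prev, x_t):
    # h_t = phi(W_h h_{t-1} + W_x x_t + b), with phi = tanh
    h_t = jnp.tanh(Wh @ h_prev + Wx @ x_t + b)
    # y_t = psi(W_y h_t + b_y), with psi = identity
    y_t = Wy @ h_t + b_y
    return h_t, y_t

xs = jax.random.normal(kx, (T, d_in))
# The same step (same weights) is applied at every time step.
h_T, ys = jax.lax.scan(step, jnp.zeros(d_h), xs)
print(ys.shape)  # (7, 2): one output per time step
```

Because `step` closes over a single set of weights, the same parameters process a sequence of any length.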
The Elman RNN maintains only one hidden state: at each time step the same computation must compress all past information into a single \(h_{t}\). This is challenging for longer time dependencies. Additionally, Elman RNNs are prone to vanishing/exploding gradients because, during backpropagation through time, gradients must repeatedly pass through \(W_{h}\) and the nonlinearity \(\phi\). In practice, these limitations are enough to opt for a modified RNN architecture.
Long Short-Term Memory
A Long Short-Term Memory (LSTM) RNN is an architectural extension of the standard Elman RNN designed to mitigate the vanishing gradient problem when modeling long-range dependencies. The LSTM introduces gating mechanisms (input, forget, and output gates) that explicitly regulate how much past information is retained and how much new information is incorporated, rather than relying on implicit storage in the recurrent weights as in an Elman RNN. Crucially, the LSTM maintains a cell state whose update follows an additive structure, forming a linear path through time. These additive dynamics enable more stable gradient propagation and substantially reduce vanishing gradients during training.
Let’s dive a little bit deeper here. Given an input \(x_{t}\) at time \(t\), a previous hidden state \(h_{t-1}\), and a previous cell state \(c_{t-1}\), we have the following equations that control the evolution of the hidden and cell states over time.
First, the forget gate \(f_{t}\) is a parametric control for the signal that should be propagated forward from the prior cell state \(c_{t-1}\) to the current cell state \(c_{t}\) that we are computing.
\[ f_{t} = \sigma(W_{f} \cdot [h_{t-1}, x_{t}] + b_{f}) \]
The input gate \(i_{t}\), you’ll notice, is almost the same computation as the forget gate and its role is to regulate how much signal should pass through from the candidate cell state \(\tilde{c}_{t}\) to the current cell state \(c_{t}\).
\[ i_{t} = \sigma(W_{i} \cdot [h_{t-1}, x_{t}] + b_{i}) \]
Our last gate, the output gate \(o_{t}\) is again the same computation as the other gates and its role is to regulate how much of the current cell state \(c_{t}\) is exposed to the current hidden state \(h_{t}\).
\[ o_{t} = \sigma(W_{o} \cdot [h_{t-1}, x_{t}] + b_{o}) \]
You may be wondering: if all the gate computations have the same form, how do the different gates perform different functions? That is a good question, and the answer is not obvious. Their distinct roles emerge from how each gate modulates a different computational path in the forward pass, and how gradients propagate through those paths during backpropagation.
In the remaining equations you can see how these gates act as controls on the input, cell state, and hidden state.
\[ \tilde{c}_{t} = \tanh(W_{c} \cdot [h_{t-1}, x_{t}] + b_{c}) \]
\[ c_{t} = f_{t} \odot c_{t-1} + i_{t} \odot \tilde{c}_{t} \]
\[ h_{t} = o_{t} \odot \tanh(c_{t}) \]
Below is a graph that depicts the flow of computations that take place to update the hidden state.
flowchart BT
x_t["$$x_{t}$$ (input)"] --> concat
h_prev["$$h_{t-1}$$ (prev hidden)"] --> concat
concat["Concatenate"] --> f_gate["Forget Gate<br> $$\sigma(W_f \cdot [h_{t-1}, x_{t}] + b_f)$$"]
concat --> i_gate["Input Gate<br/> $$\sigma(W_i \cdot [h_{t-1}, x_t] + b_i)$$"]
concat --> c_tilde["Candidate State<br/> $$\tanh(W_c \cdot [h_{t-1}, x_t] + b_c)$$"]
concat --> o_gate["Output Gate<br/> $$\sigma(W_o \cdot [h_{t-1}, x_t] + b_o)$$"]
c_prev["$$c_{t-1}$$ (prev cell)"] --> mult_f["×"]
f_gate --> mult_f
i_gate --> mult_i["×"]
c_tilde --> mult_i
mult_f --> c_t["$$c_t$$ (cell state)"]
mult_i --> c_t
c_t --> tanh_c["$$\tanh$$"]
tanh_c --> mult_o["×"]
o_gate --> mult_o
mult_o --> h_t["$$h_t$$ (hidden state)"]
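The diagram maps directly onto code. Below is a minimal single-step sketch of the LSTM equations, with toy dimensions and random weights as illustrative assumptions (the from-scratch implementation later in the post fuses the four projections into one matrix, which is equivalent but faster):

```python
import jax
import jax.numpy as jnp

d_in, d_h = 3, 4  # toy input and hidden sizes (assumptions)
key = jax.random.PRNGKey(0)
kf, ki, ko, kc, kx = jax.random.split(key, 5)
# One weight matrix per gate, each acting on the concatenation [h_{t-1}, x_t].
Wf = jax.random.normal(kf, (d_h, d_h + d_in)) * 0.1
Wi = jax.random.normal(ki, (d_h, d_h + d_in)) * 0.1
Wo = jax.random.normal(ko, (d_h, d_h + d_in)) * 0.1
Wc = jax.random.normal(kc, (d_h, d_h + d_in)) * 0.1
bf = bi = bo = bc = jnp.zeros(d_h)

def lstm_step(h_prev, c_prev, x_t):
    hx = jnp.concatenate([h_prev, x_t])
    f = jax.nn.sigmoid(Wf @ hx + bf)   # forget gate
    i = jax.nn.sigmoid(Wi @ hx + bi)   # input gate
    o = jax.nn.sigmoid(Wo @ hx + bo)   # output gate
    c_tilde = jnp.tanh(Wc @ hx + bc)   # candidate cell state
    c_t = f * c_prev + i * c_tilde     # additive cell update (linear path)
    h_t = o * jnp.tanh(c_t)            # gated exposure of the cell state
    return h_t, c_t

x = jax.random.normal(kx, (d_in,))
h, c = lstm_step(jnp.zeros(d_h), jnp.zeros(d_h), x)
```

Note that the gradient of \(c_t\) with respect to \(c_{t-1}\) passes through the elementwise product with \(f_t\) rather than through a repeated matrix multiply, which is the source of the more stable gradient flow.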
Embeddings from Language Models
Now that we have a better understanding of RNNs, specifically LSTMs, we are ready to talk about a deep learning architecture used to generate contextual embeddings: the influential Embeddings from Language Models (ELMo) architecture.
Before transformers became popular, ELMo was one of the dominant models for generating contextual embeddings. ELMo blends character‑level convolutional neural network (CNN) token encodings for handling out-of-vocabulary (OOV) words, multi‑layer bidirectional LSTMs to capture both left-to-right and right-to-left long-range dependencies, and a trainable weighted sum of their hidden states to produce word‑level embeddings that are both context‑sensitive and adaptable through fine‑tuning.
Architecture
Let’s look at a high-level map of ELMo’s architecture. The important sub-components are:
- Character-level CNN: converts each token into character-level feature vectors and is well suited to handling OOV words.
- Multi-layer bidirectional LSTM: processes the character vectors in both directions, producing hidden states that capture left and right context.
- Contextual hidden states: combine forward and backward outputs to form a deep, context‑sensitive representation of each token.
- Scalar mix parameters: learn a task-specific weighted combination of the representations from different layers, enabling effective transfer of the pretrained language model to downstream tasks without retraining the full model.
Below is a high-level graph of an ELMo model.
flowchart BT
%% 1 Input
Tokens["$$\text{Token IDs } T_{1} \text{ to } T_{n}$$"] --> CharEmb[Character Embedding]
%% 2 Character CNN (h0)
CharEmb --> CharCNN[Character CNN Output]
CharCNN --> H0["$$h_{0} \text{ Character-based token representation}$$"]
%% 3 BiLSTM layers
subgraph BiLSTM[BiLSTM]
direction RL
Title[Bidirectional LSTM Stack]
style Title fill:none,stroke:none
subgraph Layer2[Layer 2]
F2[Forward LSTM 2] --> HF2[h_fwd_2]
B2[Backward LSTM 2] --> HB2[h_bwd_2]
end
subgraph Layer1[Layer 1]
F1[Forward LSTM 1] --> HF1[h_fwd_1]
B1[Backward LSTM 1] --> HB1[h_bwd_1]
end
end
H0 --> BiLSTM
%% 4 Contextual states
subgraph HiddenStates[Contextual Representations]
H1["$$h_{1} = \text{concat fwd1 bwd1}$$"]
H2["$$h_{2} = \text{concat fwd2 bwd2}$$"]
end
HF1 --> H1
HB1 --> H1
HF2 --> H2
HB2 --> H2
%% 5 Forward & Backward LM heads
subgraph LMHeads[LM Heads]
direction LR
LMTitle[Language Modeling Heads]
style LMTitle fill:none,stroke:none
FLM[Forward LM Head] --> FSoftmax["Softmax(vocab)"]
BLM[Backward LM Head] --> BSoftmax["Softmax(vocab)"]
end
HF1 --> FLM
HF2 --> FLM
HB1 --> BLM
HB2 --> BLM
%% 6 ELMo scalar mixer
subgraph Mixer[ELMo Scalar Mixer]
A0["$$a_{0}$$"] --> Softmax["$$\text{Softmax over } a_{i}$$"]
A1["$$a_{1}$$"] --> Softmax
A2["$$a_{2}$$"] --> Softmax
Softmax --> S0["$$s_{0}$$"]
Softmax --> S1["$$s_{1}$$"]
Softmax --> S2["$$s_{2}$$"]
Sum["$$\text{Sum }s_{i}h_{i}$$"]
Gamma[Gamma]
end
H0 --> Sum
H1 --> Sum
H2 --> Sum
S0 --> Sum
S1 --> Sum
S2 --> Sum
Sum --> Gamma
%% 7 ELMo output
subgraph ELMO[ELMo Embedding]
ELMOvec["ELMo(t)"]
end
Gamma --> ELMOvec
style LMHeads fill:#5eaabf,stroke:#5eaabf,stroke-width:2px
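The scalar mixer in the diagram reduces to a few lines: softmax the per-layer scalars \(a_i\) into weights \(s_i\), take the weighted sum of the layer representations, and scale by \(\gamma\). Here is a minimal sketch with toy shapes; `a` and `gamma` are shown as fixed values but are learned parameters in practice:

```python
import jax
import jax.numpy as jnp

# Toy sizes (assumptions): batch B, tokens T, feature dim D, L+1 layers.
B, T, D, L_plus_1 = 2, 5, 8, 3
key = jax.random.PRNGKey(0)
hs = jax.random.normal(key, (L_plus_1, B, T, D))  # stacked h_0, h_1, h_2
a = jnp.zeros(L_plus_1)      # scalar mix logits (learned in practice)
gamma = jnp.array(1.0)       # global scale (learned in practice)

s = jax.nn.softmax(a)        # normalized layer weights s_i
# ELMo(t) = gamma * sum_i s_i * h_i, applied at every token position.
elmo = gamma * jnp.einsum("l,lbtd->btd", s, hs)
print(elmo.shape)  # (2, 5, 8)
```

With `a` initialized to zeros the mixer starts as a uniform average of the layers; fine-tuning then shifts weight toward whichever layers suit the downstream task.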
Pretraining Objective
ELMo uses a bidirectional language modeling objective to maximize the likelihood of observed tokens given all surrounding context. For ease of understanding, we can break down the objective into two directional components.
First, we have the forward language model objective which predicts each token \(x_{t}\) given all prior tokens.
\[ P_{fwd}(x) = \prod_{t=1}^{T} P(x_{t}|x_{1}...x_{t-1}) \]
Second, we have the backward language model objective which, as the name implies, predicts each token \(x_{t}\) given all subsequent tokens.
\[ P_{bwd}(x) = \prod_{t=1}^{T} P(x_{t}|x_{t+1}...x_{T}) \]
The total pretraining objective is the sum of the forward and backward negative log-likelihoods:
\[ \mathcal{L} = - \sum_{t=1}^{T} \log P_{fwd}(x_t | x_{<t}) \;-\; \sum_{t=1}^{T} \log P_{bwd}(x_t | x_{>t}) \]
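In code, this objective is just two cross-entropy terms added together. Below is a minimal sketch with random logits; in practice each direction is scored against its own shifted targets (\(x_{t+1}\) for the forward LM, \(x_{t-1}\) for the backward LM), which a single `targets` array stands in for here purely for shape illustration:

```python
import jax
import jax.numpy as jnp

# Toy shapes (assumptions): batch B, sequence length T, vocab size V.
B, T, V = 2, 6, 11
key = jax.random.PRNGKey(0)
k1, k2, k3 = jax.random.split(key, 3)
fwd_logits = jax.random.normal(k1, (B, T, V))  # forward LM head outputs
bwd_logits = jax.random.normal(k2, (B, T, V))  # backward LM head outputs
targets = jax.random.randint(k3, (B, T), 0, V)

def nll(logits, targets):
    # Negative log-likelihood of the target token at each position.
    logp = jax.nn.log_softmax(logits, axis=-1)
    return -jnp.take_along_axis(logp, targets[..., None], axis=-1).squeeze(-1)

# Total loss: forward NLL plus backward NLL, averaged over positions.
loss = (nll(fwd_logits, targets) + nll(bwd_logits, targets)).mean()
```

Minimizing this sum trains both directions jointly, while the two LSTMs never see each other's context at a given position.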
Hands-on Implementation
Okay then, time to get our hands dirty! We are going to build an ELMo model from scratch using the JAX deep learning ecosystem.
Specifically, we will pre-train the model using the aforementioned language modeling objective. Subsequently, we will fine-tune our scalar mix on a downstream task to evaluate our embeddings. I hope you enjoy!
In this part of the post, I’ll collapse all but the most important code cells, leaving only the ones I’ll discuss in depth open.
Processing Utilities
To get started we are going to need some toy data. I chose the c4 dataset provided by AllenAI because it has nice, long textual examples that are relatively clean.
In order to keep things running on a local machine I took one million examples for training and 50,000 examples for validation. I personally streamed the data to my local disk for faster and easier reinitialization. However, with the streaming dataset utility function we define, you can also just stream the data directly from HuggingFace.
We also need to initialize our vocabularies: one at the character level for our character-level CNN, and one at the word level to feed into our language modeling head.
For our character-level vocabulary we simply use the characters that you’d typically see in English text, in addition to an index for a padding character and an unknown character.
Our word vocabulary is specified very similarly to how we have done it in the previous posts. We build a word vocabulary of up to 100,000 words with a minimum word frequency of 20 occurrences.
We use a simple tokenizer, splitting text on spaces, which is not what ELMo canonically uses in the original paper but works nearly as well.
Finally, as I alluded to earlier, we need a way to stream batches of our data to the model as we are training. The utility defined below handles shuffling, tokenizing, and encoding both characters and words before yielding them to the model.
I have elected to use non-overlapping sequences because it is a lot faster and I am running on a local machine. For best results, use overlapping sequences; you can adjust this in the __iter__ method.
import os
from functools import partial
from collections import Counter
import itertools
import jax
import jax.numpy as jnp
from flax import nnx
import optax
import numpy as np
from datasets import load_dataset, DownloadConfig, Dataset, load_from_disk
import string
import random
import orbax.checkpoint as ocp
from datasets.utils.logging import disable_progress_bar
disable_progress_bar()
# Stream data to local disk: one time only
# You can also just stream from HF, however on reload (i.e. next epoch) you need to stream again (re-download).
def train_gen():
"""
Streams English c4 training data from HF AllenAI. We subset the data to 1M examples.
"""
stream = load_dataset(
"allenai/c4",
"en",
split="train",
streaming=True,
)
for i, row in enumerate(stream):
if i >= 1_000_000:
break
yield row
def val_gen():
"""
Streams English c4 validation data from HF AllenAI. We take a 50k-example held-out slice of the train split (rows 2,000,000 to 2,050,000).
"""
stream = load_dataset(
"allenai/c4",
"en",
split="train",
streaming=True,
)
for i, row in enumerate(stream):
if i < 2_000_000:
continue
if i >= 2_050_000:
break
yield row
# In batches of 1k save the data to disk
train_ds = Dataset.from_generator(
train_gen,
writer_batch_size=1_000,
)
train_ds.save_to_disk("c4_train")
val_ds = Dataset.from_generator(
val_gen,
writer_batch_size=1_000,
)
val_ds.save_to_disk("c4_val")
# Tokens used when padding or encountering unknown characters in a string.
PAD_CHAR, UNK_CHAR = "<pad>", "<unk>"
# Tokens used when padding or encountering unknown words in a sentence.
PAD_WORD, UNK_WORD = "<pad>", "<unk>"
# Characters that we expect to see in typical English text
char_vocab = [PAD_CHAR, UNK_CHAR] + list(string.ascii_lowercase + string.ascii_uppercase + string.digits + string.punctuation)
char_to_id = {ch: i for i, ch in enumerate(char_vocab)}
def build_word_vocab(dataset, batch_size=32, min_freq=20, max_vocab=100_000) -> tuple[dict[str, int], dict[int, str]]:
"""
Build a word–to–index mapping from a text dataset.
The function scans the provided dataset for tokenised words
(splitting on whitespace) and retains only words that appear at least
    ``min_freq`` times in the corpus. The returned dictionary is
    capped to at most ``max_vocab`` entries.
Parameters
----------
dataset
Iterable of records where each record is a mapping
(``dict``) that contains a ``"text"`` key holding the raw string.
batch_size
Number of records processed in a single counting pass.
min_freq
Minimum frequency required for a word to be included in the
resulting vocabularies.
max_vocab
Maximum number of words (excluding ``<pad>`` and ``<unk>``) to keep in
the vocabularies. The words chosen are the most frequent ones that
pass ``min_freq`` filtering.
Returns
-------
word_to_index : dict
Mapping from a word string to a unique integer ID, with
``<pad>`` mapping to 0 and ``<unk>`` mapping to 1.
index_to_word : dict
Inverse mapping from integer ID back to the corresponding word.
Notes
-----
* The tokenisation strategy is simplistic: it uses the standard
``str.split()`` which splits on any whitespace. For more advanced
tokenisation pipelines (e.g., handling sub‑words, hyphenated
compounds, etc.) replace the ``s.split()`` call accordingly.
* The frequency counter is updated in batches to keep the memory
footprint low; ``Counter.update`` is called repeatedly on
concatenated lists of tokens.
"""
counter = Counter()
buf = []
for row in dataset:
buf.append(row["text"])
if len(buf) >= batch_size:
for s in buf:
counter.update(s.split())
buf = []
if buf:
for s in buf:
counter.update(s.split())
# top words with min frequency
most_common = [(w, c) for w, c in counter.most_common(max_vocab) if c >= min_freq]
word_to_index = {PAD_WORD: 0, UNK_WORD: 1}
index_to_word = {0: PAD_WORD, 1: UNK_WORD}
for i, (w, _) in enumerate(most_common, start=2):
word_to_index[w] = i
index_to_word[i] = w
    return word_to_index, index_to_word
class StreamingTextDataLoader:
"""
A streaming data loader for text corpora that yields batches of tokenized
sequences.
This class supports streaming from an arbitrary dataset, automatically shuffling,
tokenizing into fixed‑length windows (with stride = seq_len), and encoding both
words and characters. It provides batched dictionaries containing ``word_ids``,
``char_ids`` and ``target_ids`` as JAX arrays.
Parameters
----------
ds: Iterable[dict]
A dataset yielding rows that contain a text field (e.g., "text", "sentence",
"review", or "content"). Each row should be a mapping from column name to value.
vocab: dict[str, int] | Vocab object
Mapping from token strings to unique integer IDs.
char_to_id: dict[str, int]
Mapping from characters to their integer IDs.
seq_len: int
Length of the input sequence window (number of tokens). Each yielded batch will contain sequences of this size.
word_len: int
Maximum length of each token when encoded at the character level. Tokens longer than this are truncated; shorter ones are padded with ``PAD_CHAR``.
batch_size: int
Number of examples to accumulate before yielding a batched dictionary.
shuffle_buffer: int, optional (default=2048)
Size of the buffer used for shuffling tokens from the stream. Larger buffers
provide better randomness at the cost of memory usage.
seed: int or None, optional (default=0)
Random seed for reproducibility when shuffling.
Attributes
----------
self.ds: original iterable dataset.
self.vocab: token vocabulary.
self.char_to_id: character‑to‑ID mapping.
self.seq_len: sequence length used for chunking.
self.word_len: maximum character width per token.
self.batch_size: number of samples per batch to yield.
self.shuffle_buffer: size of the shuffle buffer.
self.seed: RNG seed.
self.token_buffer: internal queue holding pre‑assembled tokens awaiting windowing.
Yields
------
dict
A dictionary with three JAX arrays:
``word_ids`` – shape ``(batch_size, seq_len)`` integer IDs for each token,
``char_ids`` – shape ``(batch_size, seq_len, word_len)`` character‑level IDs,
``target_ids`` – shifted version of ``word_ids`` used as language‑model targets.
Notes
-----
* The class performs a non‑overlapping stride split (`self.token_buffer =
self.token_buffer[self.seq_len:]`) which can be changed to an overlapping stride
if desired for better coverage.
* All returned arrays are JAX ``jnp`` objects.
"""
def __init__(self, ds, vocab, char_to_id, seq_len, word_len, batch_size, shuffle_buffer=2048, seed=0):
self.ds = ds
self.vocab = vocab
self.char_to_id = char_to_id
self.seq_len = seq_len
self.word_len = word_len
self.batch_size = batch_size
self.shuffle_buffer = shuffle_buffer
self.seed = seed
self.token_buffer = []
def _get_text_field(self, row):
for key in ["text", "sentence", "review", "content"]:
if key in row:
return row[key]
raise KeyError(f"No text field found in row keys: {list(row.keys())}")
def _encode_window(self, toks):
word_ids = [self.vocab.get(w, self.vocab.get(UNK_WORD)) for w in toks]
char_ids = np.full(
(self.seq_len, self.word_len),
self.char_to_id[PAD_CHAR],
dtype=np.int32,
)
for i, w in enumerate(toks):
cids = [self.char_to_id.get(c, self.char_to_id[UNK_CHAR]) for c in w[:self.word_len]]
char_ids[i, :len(cids)] = cids
return {
"word_ids": np.array(word_ids, dtype=np.int32),
"char_ids": char_ids,
}
def _shuffle_buffer_iter(self, ds):
buf = []
for row in ds:
buf.append(row)
if len(buf) >= self.shuffle_buffer:
random.shuffle(buf)
while buf:
yield buf.pop()
random.shuffle(buf)
while buf:
yield buf.pop()
def __iter__(self):
self.token_buffer = []
batch_words, batch_chars, batch_targets = [], [], []
for row in self._shuffle_buffer_iter(self.ds):
text = self._get_text_field(row)
new_toks = text.split()
self.token_buffer.extend(new_toks)
while len(self.token_buffer) >= self.seq_len + 1:
x_toks = self.token_buffer[:self.seq_len]
y_toks = self.token_buffer[1:self.seq_len + 1]
# Non-overlapping stride (for speed) ideally you want overlapping
self.token_buffer = self.token_buffer[self.seq_len:]
x_enc = self._encode_window(x_toks)
y_ids = np.array(
[self.vocab.get(w, self.vocab.get(UNK_WORD)) for w in y_toks],
dtype=np.int32,
)
batch_words.append(x_enc["word_ids"])
batch_chars.append(x_enc["char_ids"])
batch_targets.append(y_ids)
if len(batch_words) == self.batch_size:
yield {
"word_ids": jnp.stack(batch_words),
"char_ids": jnp.stack(batch_chars),
"target_ids": jnp.stack(batch_targets),
}
                    batch_words, batch_chars, batch_targets = [], [], []

Build Vocabulary
If you have also opted to save the toy dataset to disk, then the next step is to load it and build the word vocabulary. Note that we have already built the character vocabulary above.
hf_ds = load_from_disk("c4_train")
vocab, _ = build_word_vocab(hf_ds)

Define Architecture Components
We went over the bulk of the architecture up above with the diagrams and the LSTM overview; however, I want to focus on a component within the character-level CNN called the highway network.
The CNN applies a set of parallel one‑dimensional convolutions with varying kernel widths over the character sequences. Each filter’s output is max‑pooled across the temporal dimension, and the pooled features from all filters are concatenated to form a fixed‑size vector. This vector is then fed into a highway network, which learns a gated mixture of the raw CNN output and a transformed version of it, yielding the final character‑level embedding.
More specifically, we have our transformed piece of the input vector \(x\) defined as
\[ \text{trans} = \operatorname{ReLU}(W_{t}x + b_{t}) \]
and learnable gates defined as
\[ \text{gate} = \sigma(W_{g}x + b_{g}) \]
and we blend the two together with the following
\[ \text{highway}(x) = \text{gate} \odot \text{trans} + (1 - \text{gate}) \odot x \]
You’ll notice that it is possible for the gate to be zero, in which case the raw CNN outputs flow through unchanged. The highway layers add robustness to vanishing/exploding gradients and to noisy CNN outputs. They also allow learning more diverse representations because the network can keep low-level, high-frequency patterns or mix them into higher-level, smoothed features.
class Highway(nnx.Module):
"""
A simple ``Highway`` network module
The layer implements the classic gating mechanism that controls how much of the
transformed input should be let through versus left for a direct residual shortcut.
Parameters
----------
dim : int
Dimensionality of the input and output vectors.
rngs : nnx.Rngs
Random number generators.
Returns
-------
nnx.Module
A callable module whose ``__call__(self, x)`` method returns
.. code-block:: python
H * T + x * (1 - T)
where
* ``proj(x)`` --> ``H = relu(proj(x))`` is the transformed (candidate) signal,
* ``trans(x)`` --> ``T = sigmoid(trans(x))`` is a gate in ``[0, 1]``,
* ``x * (1 - T)`` passes through the original input weighted by the complement of
the gate.
"""
def __init__(self, dim, *, rngs: nnx.Rngs):
self.proj = nnx.Linear(dim, dim, rngs=rngs)
self.trans = nnx.Linear(dim, dim, rngs=rngs)
def __call__(self, x):
H = jax.nn.relu(self.proj(x))
T = jax.nn.sigmoid(self.trans(x))
return H * T + x * (1 - T)
class CharCNN(nnx.Module):
"""
Character‑level convolutional encoder that produces a dense
representation for each token in a sequence.
The module first embeds each input character, applies a set of 1‑D
convolutions across the character dimension, performs a global
max‑pool, concatenates the filter responses, and optionally
processes them through Highway layers and a final projection
layer.
Parameters
----------
vocab_size : int
Size of the character vocabulary. Must be at least the number
of distinct characters (plus any padding/unknown tokens) used
in the input data.
char_dim : int
Dimensionality of the learned character embeddings.
filters : Sequence[tuple[int, int]]
A list of `(width, out_channels)` pairs that specify the
kernel width and number of output channels for each 1‑D
convolution.
highway_layers : int
Number of Highway layers applied after the convolution
stack.
proj_dim : int | None, default ``None``
If given, a final linear projection is applied to the
concatenated convolution+highway features to reduce (or
expand) the dimensionality to ``proj_dim``. If ``None``,
the output dimensionality equals the total number of
convolution channels.
rngs : nnx.Rngs
Random number generator.
Notes
-----
* **Input shape**: ``char_ids`` must have shape
``[B, T, W]`` where ``B`` is the batch size, ``T`` the number of
tokens per sequence, and ``W`` the maximum number of characters
per token (words are padded to this length).
* Each convolution operates on the character dimension.
After convolution it is followed by a ReLU activation and a
channel‑wise global max‑pool over the remaining character
positions, resulting in a single scalar per filter channel.
* If ``proj_dim`` is ``None`` the output shape will be
``[B, T, total_filters]`` where
``total_filters`` is the sum of all ``out_channels`` across
the filters. If a projection is used the output shape is
``[B, T, proj_dim]``.
Returns
-------
jnp.ndarray
Character‑derived embedding of shape ``[B, T, proj_dim]`` if a
projection layer is supplied, otherwise ``[B, T, total_filters]``.
"""
def __init__(self, vocab_size, char_dim, filters, highway_layers, proj_dim=None, *, rngs: nnx.Rngs):
self.emb = nnx.Embed(vocab_size, char_dim, rngs=rngs)
self.convs = nnx.List([nnx.Conv(
in_features=char_dim,
out_features=out_channels,
kernel_size=(width,),
feature_group_count=1,
use_bias=True,
rngs=rngs
)
for (width, out_channels) in filters])
# highway layers
total_filters = sum(out for _, out in filters)
self.highways = nnx.List([Highway(total_filters, rngs=rngs) for _ in range(highway_layers)])
self.proj_dim = proj_dim
if proj_dim is not None:
self.proj = nnx.Linear(total_filters, proj_dim, rngs=rngs)
def __call__(self, char_ids):
B, T, W = char_ids.shape
# embed -> [B, T, W, D]
x = self.emb(char_ids)
# apply convs along the word-length axis: we first reshape to merge batch/time
x_flat = x.reshape((B*T, W, x.shape[-1])) # [B*T, W, D]
conv_outs = []
for conv in self.convs:
y = conv(x_flat) # [B*T, new_len, out_ch]
y = jax.nn.relu(y)
y = jnp.max(y, axis=1) # max-pool over positions
conv_outs.append(y)
x_cat = jnp.concatenate(conv_outs, axis=-1) # [B*T, total_filters]
# highways
for h in self.highways:
x_cat = h(x_cat)
if self.proj_dim is not None:
x_cat = self.proj(x_cat) # [B*T, proj_dim]
        return x_cat.reshape((B, T, -1))  # [B, T, embed_dim]
class LSTMCell(nnx.Module):
"""
A single LSTM cell implemented in ``flax.nnx`` that supports
dropout on both the input and the recurrent connection.
Parameters
----------
input_dim : int
Dimensionality of the input vector ``x_t``
hidden_dim : int
Number of LSTM hidden units; defines the size of the hidden
state ``h`` and cell state ``c``.
dropout : float, default 0.0
Dropout probability applied to both the input vector and the
recurrent hidden state when ``deterministic=False``.
rngs : nnx.Rngs | None, default ``None``
Random number generators.
Notes
-----
* Internal weight maps:
      ``Wx : input_dim -> 4 * hidden_dim``
      ``Wh : hidden_dim -> 4 * hidden_dim``
The four output channels correspond respectively to the
*input*, *forget*, *output*, and *candidate* gates.
* Dropout is applied according to the standard
“inverted” scheme (`mask / (1‑p)`), where the mask is drawn
from a Bernoulli distribution with probability `1‑dropout`.
Two independent masks are used: one for the current input
``x_t`` and one for the recurrent hidden state ``h``. A
distinct RNG key must be passed via ``jax_rng`` when
``deterministic=False``.
* The cell state is a tuple ``(h, c)`` where each component has
shape ``[batch, hidden_dim]``. The method returns a tuple
``((h_new, c_new), h_new)``; the outer tuple contains the
updated cell state, and the inner ``h_new`` is the output
vector that can be consumed by ``nnx.scan`` or another
sequence wrapper.
Returns
-------
tuple
* ``((h_new, c_new), h_new)``
where ``h_new`` and ``c_new`` are arrays of shape
``[batch, hidden_dim]`` representing the updated
hidden and cell states, respectively. The second
``h_new`` in the outer tuple is the output of this
        cell and matches the batch dimension of ``x_t``.
"""
def __init__(self, input_dim, hidden_dim, dropout=0.0, *, rngs=None):
self.Wx = nnx.Linear(input_dim, 4 * hidden_dim, rngs=rngs)
self.Wh = nnx.Linear(hidden_dim, 4 * hidden_dim, rngs=rngs)
self.hidden_dim = hidden_dim
self.dropout = dropout
def __call__(self, carry, x_t, deterministic=True, jax_rng=None):
h, c = carry
if not deterministic:
assert jax_rng is not None, "RNG key must be passed for dropout"
rng_inp, rng_rec = jax.random.split(jax_rng)
x_mask = jax.random.bernoulli(rng_inp, 1.0 - self.dropout, x_t.shape)
x_t = x_t * x_mask / (1.0 - self.dropout)
h_mask = jax.random.bernoulli(rng_rec, 1.0 - self.dropout, h.shape)
h = h * h_mask / (1.0 - self.dropout)
gates = self.Wx(x_t) + self.Wh(h)
i, f, o, g = jnp.split(gates, 4, axis=-1)
i = jax.nn.sigmoid(i)
f = jax.nn.sigmoid(f)
o = jax.nn.sigmoid(o)
g = jnp.tanh(g)
c_new = f * c + i * g
h_new = o * jnp.tanh(c_new)
return (h_new, c_new), h_new
class BiLSTMLayer(nnx.Module):
"""
A bidirectional LSTM layer built on top of :class:`LSTMCell`.
Each input sequence is processed in two directions:
* a forward LSTM that runs from the first to the last token,
* a backward LSTM that runs from the last to the first token.
The outputs of the two directions are returned separately and also
concatenated along the feature dimension.
Parameters
----------
in_dim : int
Dimensionality of input tokens.
hidden_dim : int
Size of the hidden state in each direction (`h` and `c`).
dropout : float, default 0.0
Dropout probability applied inside each :class:`LSTMCell`. The
same dropout probability is used for both forward and backward
streams.
rngs : nnx.Rngs | None, default ``None``
Random number generators
Notes
-----
* **RNG handling** -
When ``deterministic=False`` a single ``jax_rng`` is split once
into two keys (for forward and backward). These keys are then
split further into a per‑time‑step key that is passed to
:class:`LSTMCell`. For deterministic execution ``rngs_fwd`` and
``rngs_bwd`` are lists of ``None``.
* **State management** -
Each direction starts from a zero initial hidden and cell state of
shape ``(B, hidden_dim)``.
* **Output dimensions** -
For an input of shape ``(B, T, D)`` the three returned tensors have
shapes:
* ``hs_fwd``: ``(B, T, hidden_dim)``
* ``hs_bwd``: ``(B, T, hidden_dim)``
* ``hs_concat``: ``(B, T, 2 * hidden_dim)``
Returns
-------
tuple
``(hs_fwd, hs_bwd, hs_concat)``
* ``hs_fwd`` - hidden states produced by the forward LSTM,
* ``hs_bwd`` - hidden states produced by the backward LSTM,
* ``hs_concat`` - concatenation of ``hs_fwd`` and ``hs_bwd``.
"""
def __init__(self, in_dim, hidden_dim, dropout=0.0, *, rngs=None):
self.fwd = LSTMCell(in_dim, hidden_dim, dropout=dropout, rngs=rngs)
self.bwd = LSTMCell(in_dim, hidden_dim, dropout=dropout, rngs=rngs)
def __call__(self, inputs, deterministic=True, jax_rng=None):
B, T, D = inputs.shape
h0 = jnp.zeros((B, self.fwd.hidden_dim))
c0 = jnp.zeros((B, self.fwd.hidden_dim))
if not deterministic and jax_rng is not None:
# Pre-split RNGs for each time step
rng_fwd, rng_bwd = jax.random.split(jax_rng)
rngs_fwd = jax.random.split(rng_fwd, T)
rngs_bwd = jax.random.split(rng_bwd, T)
else:
rngs_fwd = [None] * T
rngs_bwd = [None] * T
def fwd_scan(carry, x_and_rng):
x_t, rng_t = x_and_rng
return self.fwd(carry, x_t, deterministic=deterministic, jax_rng=rng_t)
def bwd_scan(carry, x_and_rng):
x_t, rng_t = x_and_rng
return self.bwd(carry, x_t, deterministic=deterministic, jax_rng=rng_t)
xs_fwd = (inputs.swapaxes(0, 1), rngs_fwd)
xs_bwd = (jnp.flip(inputs, axis=1).swapaxes(0, 1), rngs_bwd)
_, hs_fwd = jax.lax.scan(fwd_scan, (h0, c0), xs_fwd)
hs_fwd = hs_fwd.swapaxes(0, 1)
_, hs_bwd_rev = jax.lax.scan(bwd_scan, (h0, c0), xs_bwd)
hs_bwd = jnp.flip(hs_bwd_rev.swapaxes(0, 1), axis=1)
return hs_fwd, hs_bwd, jnp.concatenate([hs_fwd, hs_bwd], axis=-1)
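The backward direction reuses the same left-to-right `jax.lax.scan` by flipping the sequence before the scan and flipping the outputs back afterwards. A toy running sum, standing in for the LSTM cell, shows the effect of this flip/scan/flip pattern:

```python
import jax
import jax.numpy as jnp

def step(carry, x):
    # Stand-in for an RNN cell: a running sum as the "state".
    new = carry + x
    return new, new

xs = jnp.array([1.0, 2.0, 3.0, 4.0])

# Forward: the state at position t summarises x[0..t].
_, fwd = jax.lax.scan(step, 0.0, xs)              # [1, 3, 6, 10]

# Backward: flip, scan left-to-right, then flip the outputs back.
# The state at position t now summarises x[t..T-1].
_, bwd_rev = jax.lax.scan(step, 0.0, jnp.flip(xs))
bwd = jnp.flip(bwd_rev)                           # [10, 9, 7, 4]
```

This is exactly why `hs_bwd_rev` must be flipped before concatenation: after the flip, position t of the backward stream aligns with position t of the forward stream.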
class StackedBiLSTM(nnx.Module):
"""
A stack of bidirectional LSTM layers implemented with :class:`BiLSTMLayer`.
Each layer receives the concatenated hidden states of its predecessors
(``h_fwd`` || ``h_bwd``) as input.
Parameters
----------
input_dim : int
Dimensionality of the raw token embeddings fed to the first
:class:`BiLSTMLayer`. The dimensionality of all following layers
will be `2 * hidden_dim`.
hidden_dim : int
Hidden size of each unidirectional LSTM within every
:class:`BiLSTMLayer`. The actual output of a layer
has shape ``(B, T, 2 * hidden_dim)``.
num_layers : int
Number of stacked bidirectional layers.
dropout : float, default 0.0
Dropout probability applied inside each :class:`LSTMCell`.
rngs : nnx.Rngs | None, default ``None``
Random number generators
Returns
-------
1. ``outs`` - A list containing the *raw input* followed by the
concatenated output of each layer. Hence ``len(outs) ==
num_layers + 1`` and the last element has shape
``(B, T, 2 * hidden_dim)``.
2. ``fwd_states`` - A list of the forward hidden states from each
layer (shape ``(B, T, hidden_dim)``).
3. ``bwd_states`` - A list of the backward hidden states from each
layer (shape ``(B, T, hidden_dim)``).
"""
def __init__(self, input_dim, hidden_dim, num_layers, dropout=0.0, *, rngs=None):
self.layers = nnx.List([
BiLSTMLayer(input_dim if i == 0 else 2*hidden_dim, hidden_dim, dropout=dropout, rngs=rngs)
for i in range(num_layers)
])
def __call__(self, x, deterministic=False, jax_rng=None):
outs = [x]
fwd_states, bwd_states = [], []
if not deterministic and jax_rng is not None:
rngs = jax.random.split(jax_rng, len(self.layers))
else:
rngs = [None] * len(self.layers)
for layer, r in zip(self.layers, rngs):
fwd, bwd, x = layer(x, deterministic=deterministic, jax_rng=r)
fwd_states.append(fwd)
bwd_states.append(bwd)
outs.append(x)
return outs, fwd_states, bwd_states
class LMHead(nnx.Module):
"""
Language‑model output head that projects hidden states to logits over a
target vocabulary.
The module consists of a single linear transformation that maps the
last hidden dimension of the model to a vector of size ``vocab_size``.
Parameters
----------
hidden_dim : int
Dimensionality of the input hidden states ``h``.
vocab_size : int
Size of the target vocabulary (number of output logits).
rngs : nnx.Rngs
Random number generators
Returns
-------
jax.numpy.ndarray
Logits of shape ``(B, T, vocab_size)``
"""
def __init__(self, hidden_dim, vocab_size, *, rngs: nnx.Rngs):
self.linear = nnx.Linear(hidden_dim, vocab_size, rngs=rngs)
def __call__(self, h):
        return self.linear(h)  # [B, T, V]

We define our ELMo model by initializing and chaining together the subcomponents:
- Character Level CNN
- Bidirectional LSTM
- Language Modeling Heads
We also initialize ELMo’s scalar mix parameters, which are used to adapt the embeddings during fine-tuning on specific downstream tasks, and we add model regularization in the form of dropout. Our ELMo model applies dropout in three places: at the input layer, within the LSTM layers, and at the output layer.
In our implementation you’ll notice we project the output of each layer into a common dimension so that the scalar-mix weighted sum is dimensionally consistent, and you’ll also notice that the method forward_embeddings is only used during fine-tuning, to tune the scalar mix parameters on specific downstream tasks.
Note that output dropout regularizes language-model training, not the downstream embeddings: the scalar-mixed embeddings themselves are not dropout-regularized in this implementation.
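The scalar mix itself is just a softmax-weighted sum of per-layer representations, scaled by a learnable gamma. A minimal standalone sketch with hypothetical shapes (three layers projected to a common dimension, weights initialized to zero as in the model below):

```python
import jax
import jax.numpy as jnp

B, T, D = 2, 5, 8  # batch, sequence length, common dim (hypothetical)
key = jax.random.PRNGKey(0)
# Layer 0 = projected CharCNN output; layers 1..2 = projected BiLSTM outputs.
layers = [jax.random.normal(k, (B, T, D)) for k in jax.random.split(key, 3)]

s = jnp.zeros(3)       # learnable scalar weights; zeros give a uniform mix
gamma = 1.0            # learnable scale
w = jax.nn.softmax(s)  # softmax guarantees the weights sum to 1

mixed = gamma * sum(w_i * layer for w_i, layer in zip(w, layers))
```

With zero-initialized weights the mix starts as a plain average of the layers; fine-tuning then learns which layers matter for a given downstream task.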
class ElmoModel(nnx.Module):
"""
The model implements the core components of the original ELMo
architecture: a character‑level CNN that produces sub‑word
representations, a stack of bidirectional LSTMs that encode the
sentence, and forward/backward language‑model heads. A
*scalar‑mix* (parameterised by a softmax over learnable weights)
combines the character‑CNN output and every LSTM layer into a
fixed‑dimensional semantic vector (`common_dim`). This vector can
be used as contextualised word embeddings downstream.
Parameters
----------
char_vocab_size : int
Vocabulary size for character indices.
char_dim : int
Size of the character embedding vectors.
filters : Sequence[Tuple[int, int]]
List of ``(num_filters, filter_width)`` tuples that define the
convolutional channels in the character CNN.
highway_layers : int
Number of highway network layers in the character CNN.
proj_dim : int
Dimensionality of the output of the projection layer that comes
after the character CNN.
common_dim : int
Dimensionality of the final ELMo embedding.
hidden_dim : int
Hidden state size of each BiLSTM cell.
num_layers : int
Number of stacked BiLSTM layers.
word_vocab_size : int
Size of the vocabulary for the forward and backward language‑model
heads.
input_dropout : float, default 0.1
Dropout probability applied to the output of the character CNN
during training.
lstm_dropout : float, default 0.1
Dropout probability applied within each BiLSTM layer during
training.
output_dropout : float, default 0.1
Dropout probability applied to the top‑layer LSTM states before
they are fed to the language‑model heads.
rngs : nnx.Rngs
JAX random number generator state used to initialise parameters.
Notes
-----
* The character‑CNN (`self.char_cnn`) maps the raw character IDs to
a vector of dimensionality ``proj_dim``. This vector is then
projected to the common embedding space (`common_dim`) by a
linear layer.
* Each BiLSTM layer produces a forward state of shape
`(batch, seq_len, hidden_dim)` and a backward state of the
same shape. The states of a layer are concatenated along the
feature dimension and projected to ``common_dim``.
* The scalar mix treats the character‑CNN output as layer 0 and
each BiLSTM layer as a subsequent layer. The weight vector
(`self.scalar_weights`) is soft‑maxed so that the weights sum
to 1. The result is scaled by the learnable `gamma` parameter.
* During evaluation (``deterministic=True``) drop‑outs are disabled
and the same provided RNG is used to keep the computation
deterministic.
Methods
-------
forward_backbone(char_ids, jax_rng, deterministic=True)
Compute the character embeddings, forward and backward LSTM states.
Returns
-------
char_embs : jnp.ndarray
The raw output of the character CNN, shape
`(batch, seq_len, proj_dim)`.
fwd_states : List[jnp.ndarray]
Forward LSTM states for each of the ``num_layers`` layers,
each of shape `(batch, seq_len, hidden_dim)`.
bwd_states : List[jnp.ndarray]
Backward LSTM states for each layer, each of shape
`(batch, seq_len, hidden_dim)`.
forward_logits(char_ids, jax_rng, deterministic=True)
Return the forward and backward language‑model logits together
with the intermediate representations.
Returns
-------
fwd_logits : jnp.ndarray
Forward language‑model logits, shape
`(batch, seq_len, word_vocab_size)`.
bwd_logits : jnp.ndarray
Backward language‑model logits (time‑reversed), same shape as
``fwd_logits``.
char_embs : jnp.ndarray
Raw character‑CNN embeddings (as in ``forward_backbone``).
fwd_states : List[jnp.ndarray]
Forward LSTM states.
bwd_states : List[jnp.ndarray]
Backward LSTM states.
forward_embeddings(char_embs, fwd_states, bwd_states)
Produce the final contextualised embedding vector.
Returns
-------
x : jnp.ndarray
Contextualised ELMo embedding of shape
`(batch, seq_len, common_dim)`. It is a weighted sum of the
projected character‑CNN output and each concatenated
forward/backward LSTM layer, scaled by `gamma`.
"""
def __init__(self, char_vocab_size, char_dim, filters, highway_layers,
proj_dim, common_dim, hidden_dim, num_layers, word_vocab_size,
input_dropout=0.1, lstm_dropout=0.1, output_dropout=0.1, *,
rngs: nnx.Rngs):
# Submodules
self.char_cnn = CharCNN(char_vocab_size, char_dim, filters, highway_layers, proj_dim=proj_dim, rngs=rngs)
self.bilstm = StackedBiLSTM(proj_dim, hidden_dim, num_layers, dropout=lstm_dropout, rngs=rngs)
self.fwd_head = LMHead(hidden_dim, word_vocab_size, rngs=rngs)
self.bwd_head = LMHead(hidden_dim, word_vocab_size, rngs=rngs)
# Scalar mix for ELMo embeddings
self.common_dim = common_dim
self.scalar_weights = nnx.Param(jnp.zeros(num_layers + 1))
self.gamma = nnx.Param(jnp.array(1.0))
# Projection layers to common_dim
self.layer_projections = nnx.List()
# CharCNN output to common dim
self.layer_projections.append(
nnx.Linear(proj_dim, common_dim, rngs=rngs)
)
# BiLSTM layers (2 * hidden_dim) to common_dim
for _ in range(num_layers):
self.layer_projections.append(
nnx.Linear(2 * hidden_dim, common_dim, rngs=rngs)
)
assert len(self.layer_projections) == len(self.scalar_weights.value)
# Dropout layers
self.input_dropout = nnx.Dropout(rate=input_dropout, rngs=rngs)
self.output_dropout = nnx.Dropout(rate=output_dropout, rngs=rngs)
def forward_backbone(self, char_ids, jax_rng, deterministic: bool = True):
char_embs = self.char_cnn(char_ids)
if not deterministic:
assert jax_rng is not None
rng_in, rng_lstm = jax.random.split(jax_rng)
x = self.input_dropout(char_embs, rngs=rng_in)
else:
x = char_embs
rng_lstm = jax_rng
_, fwd_states, bwd_states = self.bilstm(x, deterministic=deterministic, jax_rng=rng_lstm)
return char_embs, fwd_states, bwd_states
def forward_logits(self, char_ids, jax_rng, deterministic: bool = True):
char_embs, fwd_states, bwd_states = self.forward_backbone(
char_ids, deterministic=deterministic, jax_rng=jax_rng
)
top_fwd = fwd_states[-1]
top_bwd = bwd_states[-1]
        top_fwd = self.output_dropout(top_fwd, deterministic=deterministic)
        top_bwd = self.output_dropout(top_bwd, deterministic=deterministic)
fwd_logits = self.fwd_head(top_fwd)
bwd_logits = jnp.flip(
self.bwd_head(jnp.flip(top_bwd, axis=1)), axis=1
)
return fwd_logits, bwd_logits, char_embs, fwd_states, bwd_states
def forward_embeddings(self, char_embs, fwd_states, bwd_states):
layers = [char_embs] + [
jnp.concatenate([fwd, bwd], axis=-1)
for fwd, bwd in zip(fwd_states, bwd_states)
]
w = jax.nn.softmax(self.scalar_weights.value)
projected = [
proj(layer) for proj, layer in zip(self.layer_projections, layers)
]
x = sum(w_i * p for w_i, p in zip(w, projected))
x = self.gamma.value * x
        return x

Loss Function
We have already discussed the learning objective; below is an implementation of a masked cross-entropy loss function that excludes padding from the loss computation.
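The masking idea in miniature, as a standalone toy separate from the implementation that follows: zero out the per-token losses at pad positions and average over the remaining tokens only.

```python
import jax.numpy as jnp

# Hypothetical per-token losses for a batch of 2 sequences of length 3.
per_token_loss = jnp.array([[0.5, 1.0, 2.0],
                            [0.7, 3.0, 4.0]])
targets = jnp.array([[5, 9, 0],
                     [7, 0, 0]])  # 0 = pad id

mask = (targets != 0).astype(jnp.float32)
# Only the three non-pad losses (0.5, 1.0, 0.7) contribute.
loss = jnp.sum(per_token_loss * mask) / (jnp.sum(mask) + 1e-12)
```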
Code
def masked_cross_entropy(logits, targets, pad_id=0):
"""
Computes the average cross‑entropy loss for a batch while ignoring
padding tokens.
The function applies a *mask* to the per‑token loss so that any token
whose target index equals ``pad_id`` is dropped from the loss
calculation.
Parameters
----------
logits : jnp.ndarray
Logits produced by the model. Expected shape
``(batch, seq_len, vocab_size)``.
targets : jnp.ndarray
Ground‑truth token indices. Expected shape ``(batch, seq_len)``
with integer values in ``[0, vocab_size)``. Positions that
contain ``pad_id`` are treated as padding.
pad_id : int, default=0
The integer value used to mark padding positions in ``targets``.
Returns
-------
loss : float
The mean cross‑entropy loss over all non‑padding tokens in the
batch. The denominator is the total number of non‑padding
tokens plus a small constant ``1e-12`` to avoid division by
zero.
"""
vocab_size = logits.shape[-1]
log_probs = jax.nn.log_softmax(logits, axis=-1)
targets_onehot = jax.nn.one_hot(targets, vocab_size)
per_token_loss = -jnp.sum(targets_onehot * log_probs, axis=-1)
mask = (targets != pad_id).astype(jnp.float32)
    return jnp.sum(per_token_loss * mask) / (jnp.sum(mask) + 1e-12)

Define Training Loop
Next, we define the training loop, which encompasses the standard steps performed for each batch: executing the forward pass, computing gradients, and updating model parameters. This requires initializing the model and optimizer, configuring checkpointing to persist training state, and iterating over the training step for the desired number of epochs. These components follow conventional neural network training practice and do not introduce any ELMo-specific complexity.
There are, however, two practical considerations worth noting:
- We intentionally select hyperparameters corresponding to a reduced ELMo configuration, as the full-scale model is computationally infeasible on the available hardware.
- We limit pre-training to five epochs for the same reason.
Despite these constraints, even this abbreviated pre-training regimen yields clearly observable improvements on downstream tasks, as demonstrated in subsequent sections.
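One detail worth seeing concretely before the training step is how the targets are shifted per direction: the forward language model at position t predicts the token at t+1, while the backward model predicts the token at t-1. With a toy target sequence (hypothetical token ids):

```python
import jax.numpy as jnp

targets = jnp.array([[10, 11, 12, 13]])  # [B=1, T=4]

# Forward: logits[:, :-1] are scored against the next tokens.
fwd_pairs = targets[:, 1:]    # [[11, 12, 13]]
# Backward: logits[:, 1:] are scored against the previous tokens.
bwd_pairs = targets[:, :-1]   # [[10, 11, 12]]
```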
Code
@jax.jit
def train_step(optimizer, batch, jax_rng):
"""
Performs a single optimisation update on the ELMo‐style model.
The function runs a forward pass of the model to obtain forward and
backward language‑model logits, computes the masked cross‑entropy
loss for each direction, sums the two losses, back‑propagates the
gradients, and finally applies the optimiser's update rule.
Parameters
----------
optimizer : nnx.optim.Optimizer
An *nnx.optim.Optimizer* that owns the model to be trained.
The optimiser must expose a ``model`` attribute and provide an
``update`` method that accepts the gradient dictionary.
batch : Mapping[str, jnp.ndarray]
A batch of training data mapping the following keys to
integer arrays:
* ``"char_ids"``: shape ``(batch, seq_len, word_len)``
containing character indices for each token.
* ``"target_ids"``: shape ``(batch, seq_len)``
containing the target token indices for the language‑model head.
jax_rng : jax.random.PRNGKey
Random number generator used for dropout and other stochastic
components of the model.
Returns
-------
optimizer : nnx.optim.Optimizer
The optimiser after the update. It holds the freshly
updated model parameters.
loss : float
The scalar loss value that was optimised. It is the sum of the
forward and backward masked cross‑entropy losses.
Notes
-----
* The optimiser should store the model in the ``optimizer.model``
attribute. After the update, modifications to ``optimizer.model``
reflect the new parameters.
"""
char_ids = batch["char_ids"]
targets = batch["target_ids"]
def loss_fn(model):
fwd_logits, bwd_logits, _, _, _ = model.forward_logits(
char_ids, deterministic=False, jax_rng=jax_rng
)
fwd_loss = masked_cross_entropy(
fwd_logits[:, :-1, :], targets[:, 1:]
)
bwd_loss = masked_cross_entropy(
bwd_logits[:, 1:, :], targets[:, :-1]
)
return fwd_loss + bwd_loss
    # Compute the loss and its gradients in a single forward/backward pass
    loss, grads = nnx.value_and_grad(loss_fn)(optimizer.model)
optimizer.update(grads)
    return optimizer, loss

Code
# RNG setup
rngs = nnx.Rngs(0)
# Hyperparameters
char_vocab_size = len(char_to_id)
char_dim = 16
filters = [
(1, 64),
(2, 128),
(3, 256),
(4, 256),
(5, 256),
(6, 256),
]
highway_layers = 2
proj_dim = 512
hidden_dim = 512
num_layers = 2
word_vocab_size = len(vocab)
common_dim = 512
# Model instantiation
model = ElmoModel(
char_vocab_size=char_vocab_size,
char_dim=char_dim,
filters=filters,
highway_layers=highway_layers,
proj_dim=proj_dim,
common_dim=common_dim,
hidden_dim=hidden_dim,
num_layers=num_layers,
word_vocab_size=word_vocab_size,
input_dropout=0.1,
lstm_dropout=0.3,
output_dropout=0.1,
rngs=rngs
)
# Optimizer instantiation
tx = optax.adamw(1e-3, weight_decay=1e-2)
optimizer = nnx.ModelAndOptimizer(model, tx)

Code
class ELMoTrainer:
"""
A training wrapper for an ELMo bidirectional language model.
It handles epoch‑wise training, validation
loss computation, and an early‑stopping strategy.
Parameters
----------
optimizer : nnx.optim.Optimizer
The optimiser that owns the model to be trained. It must expose a
``model`` attribute (the trainable model) and an ``update`` method.
patience : int, optional (default=3)
Number of consecutive validation epochs without improvement on the
loss after which training is stopped early and the best model
parameters are restored.
rng_seed : int, optional (default=0)
Seed for the JAX random number generator.
Attributes
----------
optimizer : nnx.optim.Optimizer
The optimiser being used.
patience : int
See ``patience`` above.
best_loss : float
Best validation loss observed so far. Initialized to ``∞``.
wait : int
Number of epochs since the last improvement.
best_params : nnx.State | None
The model parameters that produced the best validation loss.
Stored as a JAX ``Mutable`` state so that they can be copied back
into ``optimizer.model`` on early‑stopping.
jax_rng : jax.random.PRNGKey
Current random key.
Methods
-------
train_epoch(train_loader)
Runs one epoch of training on ``train_loader``.
validate(model, val_loader)
Computes the average loss over ``val_loader``.
validate_and_stop(val_loader)
Performs validation, logs results and checks the early‑stopping
criterion. Returns ``True`` if training should stop.
"""
def __init__(self, optimizer, patience=3, rng_seed=0):
self.optimizer = optimizer
self.patience = patience
self.best_loss = float("inf")
self.wait = 0
self.best_params = None
self.jax_rng = jax.random.PRNGKey(rng_seed)
def train_epoch(self, train_loader):
total_loss = 0.0
n = 0
for batch in train_loader:
self.jax_rng, subkey = jax.random.split(self.jax_rng)
self.optimizer, loss = train_step(self.optimizer, batch, jax_rng=subkey)
total_loss += float(loss)
n += 1
return total_loss / max(1, n)
def validate(self, model, val_loader):
total_loss = 0.0
n = 0
for batch in val_loader:
targets = batch["target_ids"]
fwd_logits, bwd_logits, _, _, _ = model.forward_logits(
batch["char_ids"], jax_rng=self.jax_rng
)
# Forward predicts t+1
fwd_loss = masked_cross_entropy(fwd_logits[:, :-1, :], targets[:, 1:])
# Backward predicts t-1
bwd_loss = masked_cross_entropy(bwd_logits[:, 1:, :], targets[:, :-1])
loss = fwd_loss + bwd_loss
total_loss += float(loss)
n += 1
mean_loss = total_loss / max(1, n)
return mean_loss
def validate_and_stop(self, val_loader):
val_loss = self.validate(self.optimizer.model, val_loader)
print(f" val_loss={val_loss:.4f}")
if val_loss < self.best_loss:
self.best_loss = val_loss
self.best_params = nnx.state(self.optimizer.model)
self.wait = 0
print(" New best model saved.")
return False # continue training
self.wait += 1
print(f" No improvement ({self.wait}/{self.patience})")
if self.wait >= self.patience:
print("Early stopping triggered!")
nnx.update(self.optimizer.model, self.best_params) # restore best params
return True # stop training
        return False

Code
# Where to save model checkpoints
ckpt_dir = "./checkpoints/elmo/state/"
checkpointer = ocp.StandardCheckpointer()

When creating your checkpoints, ensure you are checkpointing the state of the model that holds the updated weights. In our case that is the model tied to the optimizer in our trainer class, not the model initialized above, as that model's weights will never be updated.
Code
trainer = ELMoTrainer(optimizer, patience=100) # No need to early stop with pre-training
train_ds = load_from_disk("c4_train")
val_ds = load_from_disk("c4_val")
epochs = 5
for epoch in range(epochs):
print(f"Epoch {epoch + 1}")
# reset streaming ds
train_loader = StreamingTextDataLoader(train_ds, vocab, char_to_id,
seq_len=128, word_len=50,
batch_size=20, shuffle_buffer=2048)
train_loss = trainer.train_epoch(train_loader)
if (epoch + 1) % 2 == 0:
# reset streaming ds
val_loader = StreamingTextDataLoader(val_ds, vocab, char_to_id,
seq_len=128, word_len=50,
batch_size=20, shuffle_buffer=2048)
stop = trainer.validate_and_stop(val_loader)
# Save checkpoint each epoch
_, state = nnx.split(trainer.optimizer.model) # Make sure you use the model with the updated weights not the initialized model from above
checkpointer.save(
os.path.abspath(
os.path.join(ckpt_dir, f"epoch{ epoch + 1 }")
),
state
)
if stop:
        break

Evaluate Learned Embeddings on Downstream Task
Okay, now that we have a pre-trained ELMo model on hand we are going to fine-tune it for a text classification task. For this task, we are going to use the Stanford Sentiment Treebank v2 (SST‑2) dataset.
Compare Random Weights to Pretrained Model
We are going to compare how an ELMo model with randomly initialized weights performs against our briefly pre-trained ELMo model. Below we initialize a random model, much the same way we initialized the model for pre-training. We also load our pre-trained weights from our saved checkpoint.
Code
# Initialize a random weights model
# RNG setup
rng = jax.random.PRNGKey(0)
rngs = nnx.Rngs(rng)
# Hyperparameters
char_vocab_size = len(char_to_id)
char_dim = 16
filters = [
(1, 64),
(2, 128),
(3, 256),
(4, 256),
(5, 256),
(6, 256),
]
highway_layers = 2
proj_dim = 512
hidden_dim = 512
num_layers = 2
word_vocab_size = len(vocab)
common_dim = 512
# Model instantiation
model_random = ElmoModel(
char_vocab_size=char_vocab_size,
char_dim=char_dim,
filters=filters,
highway_layers=highway_layers,
proj_dim=proj_dim,
common_dim=common_dim,
hidden_dim=hidden_dim,
num_layers=num_layers,
word_vocab_size=word_vocab_size,
input_dropout=0.1,
lstm_dropout=0.3,
output_dropout=0.1,
rngs=rngs
)

Code
# Load checkpointed model
# Construct an abstract version of the model (this is an empty scaffold so memory utilization is minimal)
abstract_model = nnx.eval_shape(
lambda: ElmoModel(
char_vocab_size=char_vocab_size,
char_dim=char_dim,
filters=filters,
highway_layers=highway_layers,
proj_dim=proj_dim,
common_dim=common_dim,
hidden_dim=hidden_dim,
num_layers=num_layers,
word_vocab_size=word_vocab_size,
input_dropout=0.1,
lstm_dropout=0.3,
output_dropout=0.1,
rngs=nnx.Rngs(0)
)
)
# Split to get graphdef and an abstract state
graphdef, abstract_state = nnx.split(abstract_model)
# Restore into that abstract state
ckpt_dir = "./checkpoints/elmo/state/"
epoch = 5
checkpointer = ocp.StandardCheckpointer()
restored_state = checkpointer.restore(
os.path.abspath(os.path.join(ckpt_dir, f"epoch{epoch}"))
)
# Merge to produce a real model with pretrained weights
model_trained = nnx.merge(graphdef, restored_state)

WARNING:absl:`StandardCheckpointHandler` expects a target tree to be provided for restore. Not doing so is generally UNSAFE unless you know the present topology to be the same one as the checkpoint was saved under.
If you used a GPU for pre-training and later want to load your pre-trained weights onto a different device, you will need to map the state onto the new device. The code cell below shows how to do that.
Code
# To load the model that was trained using GPU onto a CPU only device use this:
# Ensure abstract_state is placed on the current local devices
cpu_device = jax.devices('cpu')[0]
sharding = jax.sharding.SingleDeviceSharding(cpu_device)
# Construct an abstract version of the model
abstract_model = nnx.eval_shape(
lambda: ElmoModel(
char_vocab_size=char_vocab_size,
char_dim=char_dim,
filters=filters,
highway_layers=highway_layers,
proj_dim=proj_dim,
common_dim=common_dim,
hidden_dim=hidden_dim,
num_layers=num_layers,
word_vocab_size=word_vocab_size,
input_dropout=0.1,
lstm_dropout=0.3,
output_dropout=0.1,
rngs=nnx.Rngs(0)
)
)
# Split to get graphdef and an abstract state
graphdef, abstract_state = nnx.split(abstract_model)
# Map the sharding onto your abstract state leaves
abstract_state = jax.tree.map(
lambda x: jax.ShapeDtypeStruct(x.shape, x.dtype, sharding=sharding),
abstract_state
)
# Initialize and restore
ckpt_dir = "./checkpoints/elmo/state/"
epoch = 5
checkpoint_path = os.path.abspath(os.path.join(ckpt_dir, f"epoch{epoch}"))
checkpointer = ocp.StandardCheckpointer()
# Pass abstract_state
restored_state = checkpointer.restore(
checkpoint_path,
abstract_state
)
# Load into the model
model_trained = nnx.merge(graphdef, restored_state)

Classifier Architecture
Okey-dokes, let’s put together a classifier head to leverage our ELMo embeddings on our downstream text classification task. The sub-components for our network will be:
- The ELMo model as the backbone
- An attention pooling layer with explicit padding masking to collapse our sequences into a single representation
- A two-layer multilayer perceptron with dropout and ReLU activation yielding our logits
It is worth noting that the attention pooling layer here is not the same as attention in transformer models. This layer simply learns a scalar scoring function over token embeddings.
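The masking trick used in the pooling layer below adds a large negative number to padded positions before the softmax, so those positions receive near-zero weight. In isolation, with hypothetical raw scores:

```python
import jax
import jax.numpy as jnp

scores = jnp.array([[2.0, 1.0, 3.0]])  # raw token scores, [B=1, T=3]
mask = jnp.array([[1.0, 1.0, 0.0]])    # last position is padding

masked = scores + (mask - 1) * 1e9     # pad score pushed to ~ -1e9
weights = jax.nn.softmax(masked, axis=1)
# The two valid positions share essentially all the weight;
# the pad position gets a negligible share.
```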
class AttnPool(nnx.Module):
"""
Attention‑based pooling layer that collapses a sequence of vectors into a
single representation using a learnable attention weight.
Parameters
----------
dim : int
Dimensionality of each input vector (``x.shape[-1]``).
rngs : jax.random.PRNGKey
Random number generator
Returns
-------
jnp.ndarray
A tensor of shape ``(batch_size, dim)`` containing the weighted
sum of the input sequence.
"""
def __init__(self, dim, *, rngs):
self.proj = nnx.Linear(dim, 1, rngs=rngs)
def __call__(self, x, mask=None):
scores = self.proj(x).squeeze(-1)
if mask is not None:
scores = scores + (mask - 1) * 1e9
weights = jax.nn.softmax(scores, axis=1)
return jnp.sum(x * weights[..., None], axis=1)
class ElmoClassifier(nnx.Module):
"""
Sequence classifier that builds on a pre‑trained ELMo backbone.
The network follows the classic ELMo‑to‑text‑classification pipeline:
1. **ELMo backbone** – A shared ELMo `ElmoModel` is used to generate
contextual embeddings for each token (`char_ids`).
2. **Mask‑aware attention pooling** – The token‑wise embeddings are
weighted with an attention mechanism that respects the padding mask.
3. **MLP classifier** – A two‑layer MLP with dropout and ReLU non‑linearity
produces the final logits for *n_classes*.
Parameters
----------
elmo_model : ElmoModel
Pre‑trained ELMo backbone that exposes two forward
stages:
* ``forward_backbone(char_ids, deterministic, jax_rng)``
returns word‑level embeddings and forward/backward LSTM states.
* ``forward_embeddings(char_embs, fwd_states, bwd_states)``
The backbone must expose ``common_dim`` - the dimensionality of the ELMo
embeddings that the classifier consumes.
num_classes : int
Number of target classes for the downstream classification task.
dropout_rate : float, default 0.1
Dropout probability applied before and after the hidden MLP layer.
rngs : nnx.Rngs
Random number generators
Attributes
----------
backbone : ElmoModel
Reference to the ELMo backbone used for feature extraction.
dropout : nnx.Dropout
Dropout layer applied to the pooled representation and to the hidden
MLP output.
attn_pool : AttnPool
Attention pooling head that weights tokens based on the ELMo output.
classifier_hidden : nnx.Linear
First linear layer of the classification MLP.
classifier : nnx.Linear
Final linear layer producing logits of shape ``(B, num_classes)``.
Forward Pass
------------
The module expects a 3‑D integer array of character IDs
``char_ids`` with shape ``(B, T, C)`` where:
* **B** - batch size
* **T** - sequence length (token count for each example)
* **C** - number of character embeddings per token
**Deterministic flag** – When ``deterministic=True`` the dropout
layers are disabled
**Masking** – A binary mask is inferred by checking for rows of all
zeros in ``char_ids`` (treated as padding). The mask is added to the
raw attention scores before softmax, ensuring that padded positions
receive negligible weight.
Returns
-------
logits : jnp.ndarray
Unnormalised class scores with shape ``(B, num_classes)``.
The returned ``logits`` can be fed to a standard cross‑entropy loss
during training.
Notes
-----
* The attention pooling performs **softmax over the sequence dimension**
and uses a large negative constant to mask out padding before softmax
(effectively treating those positions as having negligible weight),
which is a stable and differentiable alternative to masking in the
exponent step.
* The model relies on the ELMo backbone providing *common_dim*‑dimensional
embeddings. If the backbone uses a different dimensionality, the
attributes and the MLP width need adjustment accordingly.
"""
def __init__(self, elmo_model: ElmoModel, num_classes: int, dropout_rate: float = 0.1, *, rngs: nnx.Rngs):
self.backbone = elmo_model
self.dropout = nnx.Dropout(rate=dropout_rate, rngs=rngs)
self.attn_pool = AttnPool(self.backbone.common_dim, rngs=rngs)
self.classifier_hidden = nnx.Linear(self.backbone.common_dim, self.backbone.common_dim, rngs=rngs)
self.classifier = nnx.Linear(self.backbone.common_dim, num_classes, rngs=rngs)
def __call__(self, char_ids, deterministic: bool = False, jax_rng=None):
char_embs, fwd_states, bwd_states = self.backbone.forward_backbone(char_ids, deterministic=deterministic, jax_rng=jax_rng)
elmo_embs = self.backbone.forward_embeddings(char_embs, fwd_states, bwd_states)
mask = jnp.logical_not(jnp.all(char_ids == 0, axis=-1)).astype(jnp.float32) # [B, T]
pooled = self.attn_pool(elmo_embs, mask)
x = self.dropout(pooled, deterministic=deterministic)
x = jax.nn.relu(self.classifier_hidden(x))
x = self.dropout(x, deterministic=deterministic)
logits = self.classifier(x)
        return logits

Prepare Dataset
We need to build a new data loader. The current one was designed for language‑modeling pre‑training, but our goal is now text classification. Therefore, we must produce character and word indices that operate on whole sentences, not on a sliding window.
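The shape bookkeeping in the encoder below can be seen in isolation: each sentence becomes a fixed `(seq_len,)` row of word ids plus a `(seq_len, word_len)` grid of character ids, with truncation past the limits and pad id 0 everywhere else. A simplified standalone sketch with a tiny hypothetical character vocabulary (not the full `encode_batch`):

```python
import numpy as np

seq_len, word_len = 4, 6
char_to_id = {"<pad>": 0, "<unk>": 1, "c": 2, "a": 3, "t": 4}  # hypothetical

words = "cat sat".split()[:seq_len]               # truncate to seq_len words
char_ids = np.zeros((seq_len, word_len), dtype=np.int32)  # pad id 0
for j, w in enumerate(words):
    # Unknown characters fall back to <unk>; long words are truncated.
    ids = [char_to_id.get(c, char_to_id["<unk>"]) for c in w[:word_len]]
    char_ids[j, :len(ids)] = ids
```

Here "cat" maps to known character ids, while the "s" in "sat" falls back to the unknown id, and the last two rows stay fully padded.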
Code
hf_ds = load_dataset("glue", "sst2")
train_stream = hf_ds["train"]
val_stream = hf_ds["validation"]
# DataLoader setup
batch_size = 256
seq_len = 64
word_len = 50
def encode_batch(texts, vocab, char_to_id, seq_len, word_len):
    """
    Encode a batch of raw text strings into fixed‑size integer tensors.

    Words and characters that are unseen in the supplied dictionaries are
    replaced with the special unknown token. Sequences longer than the
    requested limits are truncated, while shorter ones are padded with the
    special padding token.

    Parameters
    ----------
    texts : Iterable[str]
        A batch of raw text strings. Each element is split on whitespace
        to produce a list of words.
    vocab : Mapping[str, int]
        Word‑to‑ID dictionary. Must contain the special tokens ``"<pad>"``
        and ``"<unk>"``.
    char_to_id : Mapping[str, int]
        Character‑to‑ID dictionary. It must contain ``"<pad>"`` and
        ``"<unk>"`` for padding and unknown characters respectively.
    seq_len : int
        Maximum number of words per sentence that will be encoded. All
        sentences are truncated to this length or padded with the word
        padding token.
    word_len : int
        Maximum number of characters per word that will be encoded. Words
        longer than this length are truncated; shorter words are padded
        with the character padding token.

    Returns
    -------
    dict
        A dictionary with two keys:

        ``"word_ids"`` : numpy.ndarray of shape ``(batch_size, seq_len)``
        ``"char_ids"`` : numpy.ndarray of shape ``(batch_size, seq_len, word_len)``

        * ``word_ids[i, j]`` holds the ID of the *j*-th word in the
          *i*-th input string; ``vocab[PAD_WORD]`` if the position is
          padded.
        * ``char_ids[i, j, k]`` holds the ID of the *k*-th character of
          the *j*-th word in the *i*-th input string; ``char_to_id[PAD_CHAR]``
          if the position is padded.
    """
    PAD_WORD = "<pad>"
    UNK_WORD = "<unk>"
    PAD_CHAR = "<pad>"
    UNK_CHAR = "<unk>"

    batch_size = len(texts)
    word_ids = np.full((batch_size, seq_len), vocab[PAD_WORD], dtype=np.int32)
    char_ids = np.full((batch_size, seq_len, word_len), char_to_id[PAD_CHAR], dtype=np.int32)

    for i, text in enumerate(texts):
        toks = text.split()[:seq_len]
        # Encode word IDs
        wid = [vocab.get(w, vocab[UNK_WORD]) for w in toks]
        word_ids[i, :len(wid)] = wid
        # Encode char IDs
        for j, w in enumerate(toks):
            cids = [char_to_id.get(c, char_to_id[UNK_CHAR]) for c in w[:word_len]]
            char_ids[i, j, :len(cids)] = cids

    return {"word_ids": word_ids, "char_ids": char_ids}
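To make the padding and truncation contract concrete, here is a quick sanity check on a toy vocabulary. This is a condensed, self-contained restatement of the encoder logic (named encode_batch_demo to avoid shadowing); the words and IDs are invented for the example.

```python
import numpy as np

# Toy dictionaries, invented for the example
vocab = {"<pad>": 0, "<unk>": 1, "the": 2, "ship": 3}
char_to_id = {"<pad>": 0, "<unk>": 1, "t": 2, "h": 3, "e": 4, "s": 5, "i": 6, "p": 7}

def encode_batch_demo(texts, vocab, char_to_id, seq_len, word_len):
    # Start from all-padding tensors, then fill in observed positions
    word_ids = np.full((len(texts), seq_len), vocab["<pad>"], dtype=np.int32)
    char_ids = np.full((len(texts), seq_len, word_len), char_to_id["<pad>"], dtype=np.int32)
    for i, text in enumerate(texts):
        toks = text.split()[:seq_len]  # truncate long sentences
        word_ids[i, :len(toks)] = [vocab.get(w, vocab["<unk>"]) for w in toks]
        for j, w in enumerate(toks):
            cids = [char_to_id.get(c, char_to_id["<unk>"]) for c in w[:word_len]]
            char_ids[i, j, :len(cids)] = cids  # remaining positions stay padded
    return {"word_ids": word_ids, "char_ids": char_ids}

enc = encode_batch_demo(["the ship sailed", "the"], vocab, char_to_id, seq_len=4, word_len=5)
# "sailed" is out of vocabulary, so it maps to <unk> (ID 1);
# trailing positions stay <pad> (ID 0):
# enc["word_ids"] -> [[2, 3, 1, 0], [2, 0, 0, 0]]
```

Note that the character tensor falls back to ``<unk>`` per character, so even an out-of-vocabulary word still contributes a meaningful character-level encoding.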
def sst2_loader(train_ds, vocab, char_to_id, seq_len, word_len, batch_size):
    """
    Yield batched, encoded SST‑2 data.

    The function takes a HuggingFace ``datasets.Dataset`` containing the
    Stanford Sentiment Treebank v2 (SST‑2) data, encodes the textual
    component into word‑ and character‑IDs, and yields a dictionary of
    JAX arrays for each mini‑batch.

    Parameters
    ----------
    train_ds : :class:`datasets.Dataset`
        A HuggingFace ``Dataset`` object that must contain at least two
        columns:

        ``"sentence"`` – raw text data (a list of strings)
        ``"label"`` – integer labels (0: negative, 1: positive)
    vocab : Mapping[str, int]
        Word‑to‑ID vocabulary. Must contain the special tokens
        ``"<pad>"`` and ``"<unk>"`` used by :func:`encode_batch`.
    char_to_id : Mapping[str, int]
        Character‑to‑ID mapping. Must contain the special tokens
        ``"<pad>"`` and ``"<unk>"``.
    seq_len : int
        Maximum number of words per sentence. Sentences longer than this
        limit will be truncated; shorter ones padded to ``seq_len``.
    word_len : int
        Maximum number of characters per word. Words longer than this
        limit are truncated; shorter ones padded to ``word_len``.
    batch_size : int
        Number of examples per yielded batch.

    Yields
    ------
    dict
        A dictionary with the following JAX array entries (dtype
        ``jnp.int32``):

        ``"char_ids"`` : shape ``(batch_size, seq_len, word_len)``
        ``"word_ids"`` : shape ``(batch_size, seq_len)``
        ``"labels"`` : shape ``(batch_size,)``
    """
    ds = train_ds.shuffle()
    for i in range(0, len(ds), batch_size):
        batch = ds[i:i + batch_size]
        # Each field is a list
        texts = batch["sentence"]
        labels = batch["label"]
        # Encode text to word and char IDs (already batch sized)
        enc = encode_batch(texts, vocab, char_to_id, seq_len, word_len)
        yield {
            "char_ids": jnp.array(enc["char_ids"]),
            "word_ids": jnp.array(enc["word_ids"]),
            "labels": jnp.array(labels, dtype=jnp.int32),
        }
train_loader = sst2_loader(
    train_stream,
    vocab=vocab,
    char_to_id=char_to_id,
    seq_len=seq_len,
    word_len=word_len,
    batch_size=batch_size,
)
Phased Fine-Tuning
We fine‑tune the model in two distinct stages in order to avoid large gradients destroying the pre‑trained representations.
Phase 1 (“high‑level” adaptation):
All parameters of the ELMo encoder (the bi‑LSTM and its learned representations) are frozen.
The only trainable parameters are the scalar‑mix weights, which blend the encoder layers, and the classifier head.
Phase 2 (“deep” adaptation):
The bi‑LSTM parameters are now unfrozen so that the encoder can adjust its internal representations to the target task.
The scalar‑mix weights and the classifier head remain trainable.
A very small learning rate is employed to limit catastrophic forgetting.
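As a refresher, the scalar mix mentioned above can be sketched as follows. This is a minimal illustration of the ELMo-style weighted layer combination; scalar_mix and its argument names are ours for the sketch, not the model's actual attributes.

```python
import jax
import jax.numpy as jnp

def scalar_mix(layer_outputs, scalar_weights, gamma):
    """Blend per-layer representations with softmax-normalized scalars."""
    s = jax.nn.softmax(scalar_weights)           # one weight per encoder layer
    mixed = sum(w * h for w, h in zip(s, layer_outputs))
    return gamma * mixed                         # global task-specific scale

# Three dummy "layers" of shape (batch=2, seq=3, dim=4), filled with 0, 1, 2
layers = [jnp.full((2, 3, 4), float(k)) for k in range(3)]
# Zero scalars give uniform softmax weights: the mix is the layer mean,
# then gamma rescales it
out = scalar_mix(layers, scalar_weights=jnp.zeros(3), gamma=jnp.asarray(2.0))
```

During phase 1 only these scalars (and the head) receive gradients, so the frozen encoder layers are merely re-weighted rather than rewritten.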
The following code block explicitly lists the parameter groups updated in each phase.
# Fine-tuning phase 1 updates the parameters of the classifier head and the scalar-mix parameters in the backbone
trainable_phase1 = nnx.All(
    nnx.Param,
    nnx.Any(
        nnx.PathContains("layer_projections"),
        nnx.PathContains("scalar_weights"),
        nnx.PathContains("gamma"),
        nnx.PathContains("attn_pool"),
        nnx.PathContains("classifier"),
    )
)

# Fine-tuning phase 2 additionally unfreezes the bilstm layers in the backbone
trainable_phase2 = nnx.All(
    nnx.Param,
    nnx.Any(
        nnx.PathContains("bilstm"),
        nnx.PathContains("layer_projections"),
        nnx.PathContains("scalar_weights"),
        nnx.PathContains("gamma"),
        nnx.PathContains("attn_pool"),
        nnx.PathContains("classifier"),
    )
)
Define Training Loop
Next, we set up the training step, initialize the classifier and optimizer, and run the training loop. To keep the gradient updates confined to the intended parameters, we pass our trainable_phase object to create a DiffState.
In the first phase we instantiate two copies of the model:
- one with random‑initialized ELMo backbone weights, and
- one with pre‑trained ELMo weights.
This dual run lets us directly compare performance and check whether our pre‑training regime is effective. We run phase 1 for 10 epochs.
Below you will notice the large performance gap between the random‑weights backbone and the pre‑trained backbone, suggesting that our pre‑training regime was indeed effective.
Code
@nnx.jit(static_argnames=("trainable_phase",))
def train_step(model_opt, batch, rng, *, trainable_phase):
    """
    Perform one training step for a JAX/NNX model using a custom optimizer.

    Parameters
    ----------
    model_opt : OptimizerWrapper
        A lightweight optimizer object that holds the current model
        parameters.
    batch : dict[str, jnp.ndarray]
        A batch dictionary produced by :func:`sst2_loader`.
    rng : jax.random.PRNGKey
        Random number generator key used for dropout.
    trainable_phase
        Tells :class:`nnx.DiffState` which part of the model should
        get gradients.

    Returns
    -------
    tuple
        ``(model_opt, loss, acc, grads)``

        * ``model_opt`` - the :class:`OptimizerWrapper` after applying
          the gradient update.
        * ``loss`` (float) - the mean soft‑max cross‑entropy
          computed on the current batch.
        * ``acc`` (float) - accuracy of the model on this batch.
        * ``grads`` - a PyTree of gradients with the same
          structure as ``model_opt.model``.
    """
    # DiffState must match the optimizer's wrt argument
    diff_state = nnx.DiffState(0, trainable_phase)

    def loss_fn(model):
        logits = model(batch["char_ids"], deterministic=False, jax_rng=rng)
        loss = optax.softmax_cross_entropy_with_integer_labels(
            logits, batch["labels"]
        ).mean()
        return loss, logits

    (loss, logits), grads = nnx.value_and_grad(
        loss_fn,
        has_aux=True,
        argnums=diff_state,
    )(model_opt.model)

    model_opt.update(grads)

    preds = jnp.argmax(logits, axis=-1)
    acc = jnp.mean(preds == batch["labels"])
    return model_opt, loss, acc, grads
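The loss and accuracy computed inside train_step can be written out explicitly. Below is a small standalone illustration with made-up logits and labels, using a manual log-softmax in place of the optax helper to show exactly what it computes.

```python
import jax
import jax.numpy as jnp

# What optax.softmax_cross_entropy_with_integer_labels computes:
# the negative log-probability assigned to the correct class.
logits = jnp.array([[2.0, 0.5],
                    [0.1, 1.2]])   # (batch=2, num_classes=2), made up
labels = jnp.array([0, 1])

log_probs = jax.nn.log_softmax(logits, axis=-1)
# Pick out log P(correct class) for each example, negate, and average
loss = -jnp.take_along_axis(log_probs, labels[:, None], axis=-1).mean()

# Accuracy, exactly as in train_step
preds = jnp.argmax(logits, axis=-1)
acc = jnp.mean(preds == labels)    # both toy examples are classified correctly
```

The optax helper is preferred in the actual training step because it fuses these operations and handles numerical stability internally.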
classifier_random = ElmoClassifier(
    model_random,
    num_classes=2,
    dropout_rate=0.3,
    rngs=nnx.Rngs(jax.random.PRNGKey(1))
)

classifier_trained = ElmoClassifier(
    model_trained,
    num_classes=2,
    dropout_rate=0.3,
    rngs=nnx.Rngs(jax.random.PRNGKey(1))
)
Code
tx = optax.adamw(1e-3, weight_decay=1e-2)

model_opt_random = nnx.ModelAndOptimizer(
    classifier_random,
    tx,
    wrt=trainable_phase1
)
model_opt_trained = nnx.ModelAndOptimizer(
    classifier_trained,
    tx,
    wrt=trainable_phase1
)
num_epochs = 10
train_rng = jax.random.PRNGKey(0)

for epoch in range(num_epochs):
    print(f"\n===== Epoch {epoch + 1}/{num_epochs} =====")
    # Reinitialize/reshuffle the dataset each epoch
    train_loader = sst2_loader(
        train_stream, vocab, char_to_id, seq_len, word_len, batch_size
    )
    epoch_loss_random = []
    epoch_loss_trained = []
    epoch_acc_random = []
    epoch_acc_trained = []
    for batch in train_loader:
        train_rng, subkey = jax.random.split(train_rng)
        model_opt_random, loss_random, acc_random, grads_random = train_step(
            model_opt_random, batch, rng=subkey, trainable_phase=trainable_phase1
        )
        model_opt_trained, loss_trained, acc_trained, grads_trained = train_step(
            model_opt_trained, batch, rng=subkey, trainable_phase=trainable_phase1
        )
        epoch_loss_random.append(float(loss_random))
        epoch_acc_random.append(float(acc_random))
        epoch_loss_trained.append(float(loss_trained))
        epoch_acc_trained.append(float(acc_trained))
    # First line per epoch: random backbone; second line: pre-trained backbone
    print(f"Epoch {epoch + 1} | loss={np.mean(epoch_loss_random):.4f} | acc={np.mean(epoch_acc_random):.4f}")
    print(f"Epoch {epoch + 1} | loss={np.mean(epoch_loss_trained):.4f} | acc={np.mean(epoch_acc_trained):.4f}")
===== Epoch 1/10 =====
Epoch 1 | loss=0.6870 | acc=0.5525
Epoch 1 | loss=0.6398 | acc=0.6273
===== Epoch 2/10 =====
Epoch 2 | loss=0.6799 | acc=0.5675
Epoch 2 | loss=0.5723 | acc=0.6964
===== Epoch 3/10 =====
Epoch 3 | loss=0.6665 | acc=0.5960
Epoch 3 | loss=0.5337 | acc=0.7277
===== Epoch 4/10 =====
Epoch 4 | loss=0.6487 | acc=0.6239
Epoch 4 | loss=0.5135 | acc=0.7440
===== Epoch 5/10 =====
Epoch 5 | loss=0.6328 | acc=0.6446
Epoch 5 | loss=0.4989 | acc=0.7517
===== Epoch 6/10 =====
Epoch 6 | loss=0.6181 | acc=0.6612
Epoch 6 | loss=0.4839 | acc=0.7643
===== Epoch 7/10 =====
Epoch 7 | loss=0.6073 | acc=0.6708
Epoch 7 | loss=0.4750 | acc=0.7704
===== Epoch 8/10 =====
Epoch 8 | loss=0.5989 | acc=0.6778
Epoch 8 | loss=0.4663 | acc=0.7777
===== Epoch 9/10 =====
Epoch 9 | loss=0.5911 | acc=0.6844
Epoch 9 | loss=0.4594 | acc=0.7807
===== Epoch 10/10 =====
Epoch 10 | loss=0.5864 | acc=0.6887
Epoch 10 | loss=0.4510 | acc=0.7863
Next, we extract the fine-tuned model from the optimizer and instantiate a new model_opt via nnx.ModelAndOptimizer(). Note that for phase two we reduce the learning rate to \(5 \times 10^{-5}\). This lower rate helps preserve the pretrained bidirectional LSTM weights while allowing for small, corrective updates. We then train for five additional epochs in this phase.
Code
tx = optax.adamw(5e-5, weight_decay=1e-2)

model_opt = nnx.ModelAndOptimizer(
    model_opt_trained.model,
    tx,
    wrt=trainable_phase2
)
Code
num_epochs = 5
train_rng = jax.random.PRNGKey(0)

for epoch in range(num_epochs):
    print(f"\n===== Epoch {epoch + 1}/{num_epochs} =====")
    # Reinitialize the dataset each epoch
    train_loader = sst2_loader(
        train_stream, vocab, char_to_id, seq_len, word_len, batch_size
    )
    epoch_loss = []
    epoch_acc = []
    for batch in train_loader:
        train_rng, subkey = jax.random.split(train_rng)
        model_opt, loss, acc, grads = train_step(
            model_opt, batch, rng=subkey, trainable_phase=trainable_phase2
        )
        epoch_loss.append(float(loss))
        epoch_acc.append(float(acc))
    print(f"Epoch {epoch + 1} | loss={np.mean(epoch_loss):.4f} | acc={np.mean(epoch_acc):.4f}")
===== Epoch 1/5 =====
Epoch 1 | loss=0.4292 | acc=0.8011
===== Epoch 2/5 =====
Epoch 2 | loss=0.4156 | acc=0.8091
===== Epoch 3/5 =====
Epoch 3 | loss=0.4039 | acc=0.8141
===== Epoch 4/5 =====
Epoch 4 | loss=0.3953 | acc=0.8215
===== Epoch 5/5 =====
Epoch 5 | loss=0.3860 | acc=0.8254
Finally, let’s check the performance on the validation data.
val_loader = sst2_loader(
    val_stream,
    vocab=vocab,
    char_to_id=char_to_id,
    seq_len=seq_len,
    word_len=word_len,
    batch_size=batch_size,
)

acc = []
for batch in val_loader:
    raw_preds = model_opt.model(batch["char_ids"], deterministic=True)
    preds = np.argmax(raw_preds, axis=1)
    batch_acc = np.sum(batch["labels"] == preds) / preds.shape[0]
    acc.append(batch_acc)

print(f"Accuracy on validation: {np.mean(acc):.3f}")
Accuracy on validation: 0.742
Conclusion
This was a long one. Good job making it all the way to the end!
In this post we reviewed the differences between static and contextual embeddings, delved into the math and architecture of recurrent neural networks, and implemented ELMo practically from scratch in JAX. On top of that, we fine-tuned ELMo to perform text classification and showed the difference between pre-trained embeddings and random weights.
Coming Next
In the next post of this series, we will go over the transformer architecture, take a look at how attention works and implement a transformer model from scratch.
I hope to see you there!