Bumblebee is an Elixir library that gives you access to some very powerful models capable of incredible things. At times, it really can seem like magic. Bumblebee abstracts away many of the details of building models, turning real-world inputs into tensors, and turning tensors back into real-world outputs. These abstractions simplify your implementations; however, they can also make the library feel like a black box.
In this post, we’ll attempt to peel back the layers of abstraction by implementing GPT-2 in pure Nx. We’ll also look at how text generation works, and how we can easily wrap our implementation in an Nx.Serving for scalable and distributed model serving.
This post is based on a fantastic blog post by Jay Mody, which does the same thing but with NumPy. We won’t be code golfing here, but I think you’ll find the Nx implementation is just as straightforward as the NumPy implementation.
Before getting started, you’ll need to install a few dependencies:
Mix.install([
{:nx, "~> 0.6"},
{:exla, "~> 0.6"},
{:safetensors, "~> 0.1"},
{:tokenizers, "~> 0.3"}
])
Nx is Elixir’s numerical computing library. It is the foundation of every other library and implements the key data structure for machine learning and numerical computing: %Nx.Tensor{}. EXLA is a just-in-time compiler for Nx. It transforms Nx numerical definitions (defn) into optimized CPU and GPU programs. safetensors is an Elixir library that allows us to convert .safetensors files into collections of Nx tensors. We’ll use safetensors to load pre-trained GPT-2 parameters. Finally, the tokenizers library is an Elixir wrapper around the HuggingFace tokenizers library. It will allow us to use pre-trained tokenizers for encoding strings as tensors and decoding tensors back into strings.
Loading Parameters
The first step in creating our GPT-2 model is loading pre-trained model parameters as Nx tensors. One of the things that makes Bumblebee so awesome is its ability to take pre-trained model parameters from the Python ecosystem and use them directly in Elixir-based implementations of the same models. There are two ways we do this:
- Converting PyTorch model parameters (.bin files representing Python pickles) into Nx tensors using the unpickler library from Dashbit.
- Converting safetensors model parameters (.safetensors files) into Nx tensors using the safetensors library.
Recently, HuggingFace has started converting a majority of pre-trained model parameters to the safetensors format because it offers a variety of benefits. Fortunately, it also simplifies a lot of our parameter conversion process. For this tutorial, we’ll be using the .safetensors version of the smallest GPT-2 pre-trained model. You can download these parameters here: [https://huggingface.co/gpt2/blob/main/model.safetensors](https://huggingface.co/gpt2/blob/main/model.safetensors).
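If you’d rather stay in your Livebook, here is one way to fetch the file with :httpc from Erlang’s standard library. This is just a convenience sketch (the resolve/main URL is the direct-download form of the link above); downloading the file in your browser and saving it as gpt2.safetensors works just as well.
# Convenience sketch: download the GPT-2 checkpoint and save it where the next
# cell expects it. You can just as easily download the file manually.
{:ok, _} = Application.ensure_all_started(:inets)
{:ok, _} = Application.ensure_all_started(:ssl)

url = "https://huggingface.co/gpt2/resolve/main/model.safetensors"

{:ok, {{_, 200, _}, _headers, body}} =
  :httpc.request(:get, {String.to_charlist(url), []}, [], body_format: :binary)

File.write!("gpt2.safetensors", body)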
params = Safetensors.load!(File.read!("gpt2.safetensors"))
%{
"h.4.mlp.c_fc.weight" => #Nx.Tensor<
f32[768][3072]
[
[-4.077995545230806e-4, -0.1200379952788353, -0.012310190126299858, -0.24376054108142853, 0.1328510195016861, 0.13179974257946014, 0.02635245770215988, 0.057356610894203186, -0.06828179955482483, -0.01686907559633255, 0.049044106155633926, -0.3784016966819763, -0.03531080484390259, 0.43567171692848206, 0.02976839430630207, -0.06014109030365944, 0.18706151843070984, -0.050236742943525314, 0.11668948084115982, 0.05957753583788872, -0.14054043591022491, -0.013522407039999962, -0.06838822364807129, 0.06603340059518814, 0.1704917997121811, -0.04796731844544411, 0.13489077985286713, 0.09338540583848953, -0.4719774127006531, -0.023174606263637543, -0.035188376903533936, 0.0310268085449934, 0.10190644860267639, 0.12008801102638245, -0.10986845195293427, 0.24427762627601624, -0.26449140906333923, 0.03904227167367935, -0.0342307947576046, 0.003329918021336198, 0.10291037708520889, 0.022748058661818504, 0.012120273895561695, 0.017654111608862877, 0.061049483716487885, -0.25464147329330444, 0.017738372087478638, -0.1125129833817482, 0.019292231649160385, ...],
...
]
>,
...
}
You’ll notice the model parameters consist of a map of strings to tensors. The map’s keys look like h.4.mlp.c_fc.weight. The last segment of the dot-delimited string is the name of that particular parameter. The preceding segments map to nested modules. Essentially, if you had a module like:
class MyModule(nn.Module):
    def __init__(self, *args):
        super().__init__()
        self.fc1 = nn.Linear(64, 32)

    def forward(self, x):
        return F.relu(self.fc1(x))
in PyTorch, this would map to the parameter keys {module_name}.fc1.weight and {module_name}.fc1.bias. In order to work with these parameters, we need to “unflatten” them so that they’re easy to work with in our pure Nx implementation. This is one of the first things Bumblebee does automatically for you: it maps this flattened representation of a model’s parameters to a representation Axon can use.
For this example, we’ll convert the flattened map into a nested map where each value represents the input parameters to a function or layer in our GPT-2 implementation:
blocks =
Enum.reduce(params, %{}, fn {key, value}, acc ->
case String.split(key, ".") do
["h", block_num, inner_block_name, layer_name, param_name] ->
init = %{inner_block_name => %{layer_name => %{param_name => value}}}
Map.update(acc, "block_#{block_num}", init, fn block_params ->
inner_init = %{layer_name => %{param_name => value}}
Map.update(block_params, inner_block_name, inner_init, fn inner_block_params ->
layer_init = %{param_name => value}
Map.update(inner_block_params, layer_name, layer_init, fn layer_params ->
Map.put(layer_params, param_name, value)
end)
end)
end)
["h", block_num, layer_name, param_name] ->
init = %{layer_name => %{param_name => value}}
Map.update(acc, "block_#{block_num}", init, fn block_params ->
layer_init = %{param_name => value}
Map.update(block_params, layer_name, layer_init, fn layer_params ->
Map.put(layer_params, param_name, value)
end)
end)
[layer_name, param_name] ->
Map.update(acc, layer_name, %{param_name => value}, fn layer_params ->
Map.put(layer_params, param_name, value)
end)
end
end)
%{
"block_0" => %{
"attn" => %{
"bias" => #Nx.Tensor<
f32[1][1][1024][1024]
[
[
[
[1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...],
...
]
]
]
>,
"c_attn" => %{...}
...
}
}
}
After running this, you’ll have a nested map that we can recurse over to build our model.
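As a quick sanity check, we can look up a single leaf in the nested map. For the 124M-parameter GPT-2 checkpoint we downloaded, the fused query/key/value projection in the first block should map 768 features to 3 * 768 = 2304:
# Sanity check on the nested structure: c_attn fuses the q, k, and v
# projections, so its weight has shape {768, 2304} for this checkpoint.
blocks["block_0"]["attn"]["c_attn"]["weight"] |> Nx.shape()
#=> {768, 2304}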
Encoding Text
Another great convenience Bumblebee offers is its handling of pre- and post-processing for machine learning tasks. When dealing with text, this manifests as tokenization. Language models rely on tokenizers to convert input sequences into discrete numerical representations of the text. Most often, this comes in the form of subword tokenization. Models like GPT-2 rely on probabilistic techniques to build fixed-size vocabularies of subwords. These vocabularies are then used to translate text into sequences of integers and vice-versa.
Bumblebee makes use of HuggingFace’s tokenizers library for much of the pre/post-processing related to text-based models like GPT-2. For this tutorial, all we really need is a module that:
- Instantiates a new tokenizer from a pre-trained one
- Encodes text as a tensor
- Decodes integers back to text
We can implement our encoder like this:
defmodule Encoder do
defstruct [:tokenizer]
def new(model_id) do
{:ok, tokenizer} = Tokenizers.Tokenizer.from_pretrained(model_id)
%__MODULE__{tokenizer: tokenizer}
end
def encode(%{tokenizer: tokenizer}, text) do
{:ok, encoding} = Tokenizers.Tokenizer.encode(tokenizer, text)
Nx.tensor(Tokenizers.Encoding.get_ids(encoding))
|> Nx.new_axis(0)
end
def decode(%{tokenizer: tokenizer}, id) do
{:ok, token} = Tokenizers.Tokenizer.decode(tokenizer, [id])
token
end
end
Now, we can instantiate a new “encoder” using Encoder.new/1, passing in a HuggingFace tokenizer id. This id maps to a HuggingFace model repository:
encoder = Encoder.new("gpt2")
%Encoder{
tokenizer: #Tokenizers.Tokenizer<[
vocab_size: 50257,
byte_fallback: false,
continuing_subword_prefix: "",
dropout: nil,
end_of_word_suffix: "",
fuse_unk: false,
model_type: "bpe",
unk_token: nil
]>
}
Then, we can use this encoder to encode text as tensors:
Encoder.encode(encoder, "Hello world!")
#Nx.Tensor<
s64[1][3]
[
[15496, 995, 0]
]
>
And decode integers to text:
Encoder.decode(encoder, 15496)
"Hello"
And that’s really all we need! Bumblebee also takes care of things like attention masks, padding inputs, etc. but those are outside the scope of this blog post.
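One more convenience worth knowing about before we move on: the Tokenizers.Tokenizer.decode/2 call that Encoder.decode/2 wraps also accepts a list of ids, so you can round-trip an entire sequence at once:
# Decoding the full "Hello world!" sequence from above in one call.
{:ok, text} = Tokenizers.Tokenizer.decode(encoder.tokenizer, [15496, 995, 0])
text
#=> "Hello world!"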
Modeling Transformers with Nx
Now it’s time to model GPT-2 as Nx numerical definitions. Numerical definitions (defn) are just functions that support a special subset of the Elixir programming language. defn is meant to be a boundary for pure Nx code because everything within a defn has the potential to be JIT compiled. We can build our model up such that each “layer” or “block” maps to a numerical definition. Transformer models like GPT are relatively simple and consist of a few steps:
- An embedding of the input sequence (token and positional)
- A series of “transformer” blocks
- A final layer normalization and linear transformation
We can start by implementing our “top-level” model with this flow:
defmodule GPT do
import Nx.Defn
defn predict(input, wte, wpe, blocks, ln_f, opts \\ []) do
opts = keyword!(opts, [n_head: 12])
input
|> embedding(wte, wpe)
|> transformer(blocks, n_head: opts[:n_head])
|> layer_norm(ln_f)
|> Nx.dot([-1], wte["weight"], [-1])
end
end
That’s really all our high-level model does. Of course, we still need to implement many of these layers. Let’s start with the embedding layer:
defn embedding(x, %{"weight" => wte}, %{"weight" => wpe}) do
position_ids = Nx.iota({Nx.axis_size(x, 0), Nx.axis_size(x, 1)}, axis: -1)
Nx.take(wte, x) + Nx.take(wpe, position_ids)
end
Our embedding layer consists of a token embedding and a positional embedding. The token embedding learns to map input tokens in our vocabulary to vectors. The operation is essentially a “lookup” of the embedding at the index given by the integer ID in the input sequence. The positional embedding does exactly the same thing; however, its IDs map to positions in the input sequence. The idea is that the result encodes both semantic information from the token embedding and positional information from the positional embedding.
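To make the “lookup” intuition concrete, here is Nx.take applied to a toy embedding table with made-up values (a vocabulary of three tokens and an embedding size of two):
# Toy token embedding table: 3 tokens in the vocabulary, embedding size 2.
toy_wte = Nx.tensor([[0.0, 0.1], [1.0, 1.1], [2.0, 2.1]])

# "Looking up" ids 2 and 0 for a batch of one sequence returns the matching rows.
Nx.take(toy_wte, Nx.tensor([[2, 0]]))
#=> shape {1, 2, 2}, values [[[2.0, 2.1], [0.0, 0.1]]]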
Now things get interesting: we need to implement our transformer. The transformer portion of the model consists of a number of transformer_blocks applied successively. In other words, the transformer implementation looks like this:
deftransform transformer(x, params, opts \\ []) do
  n_head = opts[:n_head]

  params
  # sort the blocks by their numeric index; plain map iteration orders the keys
  # as strings (block_0, block_1, block_10, block_11, block_2, ...)
  |> Enum.sort_by(fn {"block_" <> idx, _} -> String.to_integer(idx) end)
  |> Enum.reduce(x, fn {_block_name, block_params}, x ->
    transformer_block(x, block_params, n_head: n_head)
  end)
end
This layer is implemented as a deftransform, which is an escape hatch for using regular Elixir functions inside numerical definitions. Essentially, deftransform operates on Nx expressions, which means we can use things like Enum.reduce to lazily build up our operations. We need it here so we can apply transformer_block to the input x once per block in our parameters. Note that we sort the blocks by their numeric index first; plain map iteration orders the keys as strings, which would put block_10 and block_11 before block_2.
Now we can implement transformer_block. The transformer_block consists of:
- Layer normalization
- Multi-head self-attention
- Residual
- Layer normalization
- Point-wise feed-forward network (FFN)
We can then implement our transformer block like this:
defn transformer_block(
       x,
       %{"mlp" => mlp, "attn" => attn, "ln_1" => ln_1, "ln_2" => ln_2},
       opts \\ []
     ) do
  opts = keyword!(opts, n_head: 12)

  # attention sub-block with a residual connection
  x =
    x
    |> layer_norm(ln_1)
    |> mha(attn, n_head: opts[:n_head])
    |> Nx.add(x)

  # feed-forward sub-block with a residual connection
  x + ffn(layer_norm(x, ln_2), mlp)
end
And that’s pretty much it! Okay, so now we need to go about implementing the layers used here. We’ll start with mha, or multi-head self-attention, as that is the “meat” of the transformer. There is a lot of literature on what exactly attention is, so I will omit lengthy explanations and mathematical details. Attention is an operation that computes the relative importance of the tokens in two sequences to one another.
Essentially, we compute a relationship matrix between two sequences. In transformers, we use self-attention to compute the relationship between an input sequence and itself. The “multi-head” terminology comes from splitting the embedded representation of our text into multiple heads and computing attention multiple times. It’s kind of like an ensembling technique: rather than getting a single relationship (attention) matrix per transformer block, we get several.
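Before we dive into the full implementation, here is a tiny single-head illustration of that relationship matrix on made-up tensors: for a sequence of three tokens, the scaled scores form a {3, 3} matrix with one entry per pair of tokens.
# Toy single-head attention scores: 3 tokens, head size 4, arbitrary values.
q = Nx.iota({3, 4}, type: :f32)
k = Nx.iota({3, 4}, type: :f32)

scores =
  q
  |> Nx.dot([1], k, [1])
  |> Nx.divide(Nx.sqrt(4))

Nx.shape(scores)
#=> {3, 3}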
Our mha implementation consists of the following steps:
- Compute a linear projection of the input
- Split the input into query, key, and value tensors
- Split q (query), k (key), and v (value) to use multiple heads
- Compute a causal_mask
- Perform the attention operation
- Compute the output projection
In code this looks like:
defn mha(x, %{"c_attn" => c_attn, "c_proj" => c_proj}, opts \\ []) do
  opts = keyword!(opts, n_head: 12)

  x = linear(x, c_attn)
  {q, k, v} = split_qkv(x)

  q = split_heads(q, opts[:n_head])
  k = split_heads(k, opts[:n_head])
  v = split_heads(v, opts[:n_head])

  # the sequence is along axis 1, so the causal mask is a {seq, seq} matrix;
  # positions a token should not attend to get a large negative value
  seq_len = Nx.axis_size(x, 1)
  causal_mask = (1 - Nx.tri(seq_len, seq_len)) * -1.0e10

  out = attention(q, k, v, causal_mask)
  linear(out, c_proj)
end
From here, we need to implement the logic to split x into {q, k, v}, like this:
deftransformp split_qkv(tensor) do
split_size = div(Nx.axis_size(tensor, -1), 3)
q = tensor[[.., .., 0..(split_size - 1)]]
k = tensor[[.., .., split_size..(2*split_size - 1)]]
v = tensor[[.., .., 2*split_size..-1//1]]
{q, k, v}
end
This essentially slices the input tensor into 3 distinct tensors that will be used for our attention operation. Next, we need to implement the logic for reshaping q, k, and v into multiple heads:
deftransformp split_heads(tensor, n_heads) do
{batch, seq, _dim} = Nx.shape(tensor)
Nx.reshape(tensor, {batch, seq, n_heads, :auto})
end
Finally, we need to implement our actual attention operation:
defn attention(q, k, v, mask) do
  k = Nx.transpose(k, axes: [0, 2, 1, 3])
  q = Nx.transpose(q, axes: [0, 2, 1, 3])
  v = Nx.transpose(v, axes: [0, 2, 1, 3])

  q
  |> Nx.divide(Nx.sqrt(Nx.axis_size(q, -1)))
  |> Nx.dot([3], [0, 1], k, [3], [0, 1])
  # add the causal mask to the raw scores before the softmax
  |> Nx.add(mask)
  |> softmax()
  |> Nx.dot([3], [0, 1], v, [2], [0, 1])
  |> Nx.transpose(axes: [0, 2, 1, 3])
  |> flatten_heads()
end
The attention implementation mostly just consists of some shape manipulations and dot products. There is a softmax operation, which is used to normalize the attention weights prior to computing the final attention output. We’ll implement that in a bit. First, let’s implement flatten_heads, which combines our multiple attention heads into a single output:
deftransformp flatten_heads(tensor) do
shape = Nx.shape(tensor)
rank = Nx.rank(tensor)
new_shape =
shape
|> Tuple.delete_at(rank - 1)
|> put_elem(rank - 2, :auto)
Nx.reshape(tensor, new_shape)
end
Now we need to go back and fill in some of our other missing pieces. We’ll start with ffn. ffn represents a basic feed-forward neural network (multi-layer perceptron):
defn ffn(x, %{"c_fc" => c_fc, "c_proj" => c_proj}) do
x
|> linear(c_fc)
|> gelu()
|> linear(c_proj)
end
This consists of a projection up with a “linear” or “dense” layer followed by a GeLU activation function and then a projection down with another linear layer. Next, we can implement our linear layer like this:
@doc """
Linear layer.
## Examples
iex> {x, key} = Nx.Random.uniform(Nx.Random.key(42), shape: {32, 128})
iex> {w, key} = Nx.Random.uniform(key, shape: {128, 256})
iex> {b, _key} = Nx.Random.uniform(key, shape: {256})
iex> out = GPT.linear(x, %{"weight" => w, "bias" => b})
iex> Nx.shape(out)
{32, 256}
iex> Nx.type(out)
{:f, 32}
"""
defn linear(x, %{"weight" => w, "bias" => b}) do
b + Nx.dot(x, w)
end
Note that I’ve added doctests to this implementation. When working in a Livebook, you can add these doctests and they will automatically run when the module is compiled! This will help you check your implementation as you move forward.
Our linear layer simply multiplies the input by a weight matrix and then adds a bias.
The final “layer” we need to implement (besides our activation functions) is layer normalization. Layer normalization normalizes the input and then scales and shifts the normalized input according to learned parameters:
@doc """
Applies Layer Normalization.
## Examples
iex> x = Nx.tensor([[2, 2, 3], [-5, 0, 1]])
iex> actual = GPT.layer_norm(x, %{"weight" => Nx.broadcast(1.0, {2, 1}), "bias" => Nx.broadcast(0.0, {2, 1})})
iex> expected = Nx.tensor([
...> [-0.70709, -0.70709, 1.41418],
...> [-1.397, 0.508, 0.889]
...> ])
iex> Nx.all_close(actual, expected, atol: 1.0e-3)
#Nx.Tensor<
u8
1
>
"""
defn layer_norm(x, %{"weight" => w, "bias" => b}, opts \\ []) do
opts = keyword!(opts, [eps: 1.0e-5])
mean = Nx.mean(x, axes: [-1], keep_axes: true)
std_dev = Nx.standard_deviation(x, axes: [-1], keep_axes: true)
x = (x - mean) / (std_dev + opts[:eps])
w * x + b
end
With our layers implemented, we just need to implement the two “activation” functions used in our implementation and then our model is ready! Activation functions are point-wise (applied to each entry in a tensor) non-linear functions. We’ll start with GeLU:
@doc """
Applies GeLU Activation.
## Examples
iex> actual = GPT.gelu(Nx.tensor([[1, 2], [-2, 0.5]]))
iex> expected = Nx.tensor([[0.84119, 1.9546], [-0.0454, 0.34571]])
iex> Nx.all_close(actual, expected, atol: 1.0e-3)
#Nx.Tensor<
u8
1
>
"""
defn gelu(x) do
gaussian_const = Nx.sqrt(2 / Nx.Constants.pi())
0.5 * x * (1 + Nx.tanh(gaussian_const * (x + 0.044715 * x ** 3)))
end
And then Softmax:
@doc """
Applies Softmax Activation.
## Examples
iex> actual = GPT.softmax(Nx.tensor([[2, 100], [-5, 0]]))
iex> expected = Nx.tensor([[2.74878501e-43, 1.0],[6.69285092e-03, 9.93307149e-01]])
iex> Nx.all_close(actual, expected, atol: 1.0e-3)
#Nx.Tensor<
u8
1
>
"""
defn softmax(x) do
exp_x = Nx.exp(x - Nx.reduce_max(x, axes: [-1], keep_axes: true))
exp_x / Nx.sum(exp_x, axes: [-1], keep_axes: true)
end
And that’s it! Overall, your GPT module should look like this:
defmodule GPT do
import Nx.Defn
defn predict(input, wte, wpe, blocks, ln_f, opts \\ []) do
opts = keyword!(opts, n_head: 12)
input
|> embedding(wte, wpe)
|> transformer(blocks, n_head: opts[:n_head])
|> layer_norm(ln_f)
|> Nx.dot([-1], wte["weight"], [-1])
end
defn embedding(x, %{"weight" => wte}, %{"weight" => wpe}) do
position_ids = Nx.iota({Nx.axis_size(x, 0), Nx.axis_size(x, 1)}, axis: -1)
Nx.take(wte, x) + Nx.take(wpe, position_ids)
end
deftransform transformer(x, params, opts \\ []) do
  params
  # sort the blocks numerically so they are applied in order (string ordering
  # would put block_10 and block_11 before block_2)
  |> Enum.sort_by(fn {"block_" <> idx, _} -> String.to_integer(idx) end)
  |> Enum.reduce(x, fn {_block_name, block_params}, x ->
    transformer_block(x, block_params, n_head: opts[:n_head])
  end)
end
defn transformer_block(
       x,
       %{"mlp" => mlp, "attn" => attn, "ln_1" => ln_1, "ln_2" => ln_2},
       opts \\ []
     ) do
  opts = keyword!(opts, n_head: 12)

  # attention sub-block with a residual connection
  x =
    x
    |> layer_norm(ln_1)
    |> mha(attn, n_head: opts[:n_head])
    |> Nx.add(x)

  # feed-forward sub-block with a residual connection
  x
  |> layer_norm(ln_2)
  |> ffn(mlp)
  |> Nx.add(x)
end
defn mha(x, %{"c_attn" => c_attn, "c_proj" => c_proj}, opts \\ []) do
  opts = keyword!(opts, n_head: 12)

  x = linear(x, c_attn)
  {q, k, v} = split_qkv(x)

  q = split_heads(q, opts[:n_head])
  k = split_heads(k, opts[:n_head])
  v = split_heads(v, opts[:n_head])

  # the sequence is along axis 1, so the causal mask is {seq, seq}
  seq_len = Nx.axis_size(x, 1)
  causal_mask = (1 - Nx.tri(seq_len, seq_len)) * -1.0e10

  out = attention(q, k, v, causal_mask)
  linear(out, c_proj)
end
deftransformp split_qkv(tensor) do
split_size = div(Nx.axis_size(tensor, -1), 3)
q = tensor[[0..-1//1, 0..-1//1, 0..(split_size - 1)]]
k = tensor[[0..-1//1, 0..-1//1, split_size..(2 * split_size - 1)]]
v = tensor[[0..-1//1, 0..-1//1, (2 * split_size)..-1//1]]
{q, k, v}
end
deftransformp split_heads(tensor, n_head) do
{batch, seq, _dim} = Nx.shape(tensor)
Nx.reshape(tensor, {batch, seq, n_head, :auto})
end
defn attention(q, k, v, mask) do
  k = Nx.transpose(k, axes: [0, 2, 1, 3])
  q = Nx.transpose(q, axes: [0, 2, 1, 3])
  v = Nx.transpose(v, axes: [0, 2, 1, 3])

  q
  |> Nx.divide(Nx.sqrt(Nx.axis_size(q, -1)))
  |> Nx.dot([3], [0, 1], k, [3], [0, 1])
  # add the causal mask to the raw scores before the softmax
  |> Nx.add(mask)
  |> softmax()
  |> Nx.dot([3], [0, 1], v, [2], [0, 1])
  |> Nx.transpose(axes: [0, 2, 1, 3])
  |> flatten_heads()
end
deftransformp flatten_heads(tensor) do
shape = Nx.shape(tensor)
rank = Nx.rank(tensor)
new_shape =
shape
|> Tuple.delete_at(rank - 1)
|> put_elem(rank - 2, :auto)
Nx.reshape(tensor, new_shape)
end
defn ffn(x, %{"c_fc" => c_fc, "c_proj" => c_proj}) do
x
|> linear(c_fc)
|> gelu()
|> linear(c_proj)
end
@doc """
Linear layer.
## Examples
iex> {x, key} = Nx.Random.uniform(Nx.Random.key(42), shape: {32, 128})
iex> {w, key} = Nx.Random.uniform(key, shape: {128, 256})
iex> {b, _key} = Nx.Random.uniform(key, shape: {256})
iex> out = GPT.linear(x, %{"weight" => w, "bias" => b})
iex> Nx.shape(out)
{32, 256}
iex> Nx.type(out)
{:f, 32}
"""
defn linear(x, %{"weight" => w, "bias" => b}) do
x |> Nx.dot(w) |> Nx.add(b)
end
@doc """
Applies Layer Normalization.
## Examples
iex> x = Nx.tensor([[2, 2, 3], [-5, 0, 1]])
iex> actual = GPT.layer_norm(x, %{"weight" => Nx.broadcast(1.0, {2, 1}), "bias" => Nx.broadcast(0.0, {2, 1})})
iex> expected = Nx.tensor([
...> [-0.70709, -0.70709, 1.41418],
...> [-1.397, 0.508, 0.889]
...> ])
iex> Nx.all_close(actual, expected, atol: 1.0e-3)
#Nx.Tensor<
u8
1
>
"""
defn layer_norm(x, %{"weight" => w, "bias" => b}, opts \\ []) do
opts = keyword!(opts, eps: 1.0e-5)
mean = Nx.mean(x, axes: [-1], keep_axes: true)
variance = Nx.variance(x, axes: [-1], keep_axes: true)
x = (x - mean) / Nx.sqrt(variance + opts[:eps])
w * x + b
end
@doc """
Applies GeLU Activation.
## Examples
iex> actual = GPT.gelu(Nx.tensor([[1, 2], [-2, 0.5]]))
iex> expected = Nx.tensor([[0.84119, 1.9546], [-0.0454, 0.34571]])
iex> Nx.all_close(actual, expected, atol: 1.0e-3)
#Nx.Tensor<
u8
1
>
"""
defn gelu(x) do
0.5 * x * (1 + Nx.tanh(Nx.sqrt(2 / Nx.Constants.pi()) * (x + 0.044715 * Nx.pow(x, 3))))
end
@doc """
Applies Softmax Activation.
## Examples
iex> actual = GPT.softmax(Nx.tensor([[2, 100], [-5, 0]]))
iex> expected = Nx.tensor([[2.74878501e-43, 1.0],[6.69285092e-03, 9.93307149e-01]])
iex> Nx.all_close(actual, expected, atol: 1.0e-3)
#Nx.Tensor<
u8
1
>
"""
defn softmax(x) do
exp_x = Nx.exp(x - Nx.reduce_max(x, axes: [-1], keep_axes: true))
exp_x / Nx.sum(exp_x, axes: [-1], keep_axes: true)
end
end
4 doctests, 0 failures
{:module, GPT, <<70, 79, 82, 49, 0, 0, 59, ...>>, true}
With our model implemented, we can create a predict function which takes an input and the parameters and produces an output tensor:
predict_fun = fn input, params ->
{wte, params} = Map.pop!(params, "wte")
{wpe, params} = Map.pop!(params, "wpe")
{ln_f, params} = Map.pop!(params, "ln_f")
GPT.predict(input, wte, wpe, params, ln_f)
end
#Function<41.3316493/2 in :erl_eval.expr/6>
And of course, we’ll want to JIT compile our predict function so it runs accelerated:
predict_fun = Nx.Defn.jit(predict_fun, compiler: EXLA)
#Function<134.64864510/2 in Nx.Defn.Compiler.fun/2>
Now we can get some input and pass it to our model!
input = Encoder.encode(encoder, "Hello World!")
predict_fun.(input, blocks)
#Nx.Tensor<
f32[1][3][50257]
EXLA.Backend<host:0, 0.879302795.2891055124.147515>
[
[
[-17.95338249206543, -10.991355895996094, -13.45356273651123, -19.010509490966797, -22.15827751159668, -19.9477596282959, -14.55197525024414, -11.953147888183594, -12.225095748901367, -22.248506546020508, -14.363884925842285, -9.260599136352539, -10.11651611328125, -12.922428131103516, -14.436551094055176, -12.716976165771484, -10.315630912780762, -14.780750274658203, -10.36999797821045, -11.931602478027344, -15.256696701049805, -12.662351608276367, -13.149083137512207, -13.825837135314941, -18.168048858642578, -11.906233787536621, -18.675334930419922, -15.466197967529297, -15.000768661499023, -15.671730041503906, -15.362894058227539, -10.682412147521973, -13.200881958007812, -15.150506019592285, -14.514249801635742, -9.522102355957031, -10.189918518066406, -14.229316711425781, -13.961867332458496, -13.763506889343262, -10.999066352844238, -14.516007423400879, -14.772758483886719, -12.857023239135742, -13.77109146118164, -10.552913665771484, -12.120234489440918, -18.59151840209961, -13.763373374938965, -12.64120101928711, ...],
...
]
]
>
And that’s it! You just implemented GPT-2 in pure Nx! Now we can use this model to generate some text.
Generating Text
The output of the GPT-2 model seems like just a random tensor; however, we can use it to generate text. You’ll notice the final dimension of the tensor has a size of 50257. This maps exactly to the size of the GPT-2 vocabulary. We can turn our model output into a next-token prediction with the following computation:
output = predict_fun.(input, blocks)
logits = output[[.., -1]]
next_token = Nx.argmax(logits, axis: -1)
#Nx.Tensor<
s64[1]
EXLA.Backend<host:0, 0.879302795.2891055124.147519>
[84]
>
This essentially grabs the next-token logits from the output tensor and then computes the argmax of the logits, which represents the ID of the next token. We can repeatedly append the predicted token to our input sequence to get successive next-token predictions from our model. This is pretty easily modeled with Enum.reduce_while:
defmodule Generator do
def generate(predict_fun, encoder, input, params, eos_id, max_seq_len) do
encoded_input = Encoder.encode(encoder, input)
seq_len = Nx.axis_size(encoded_input, 1)
Enum.reduce_while(seq_len..max_seq_len, encoded_input, fn _idx, current_input ->
output = predict_fun.(current_input, params)
logits = output[[.., -1]]
next_token = Nx.argmax(logits, axis: -1, keep_axis: true)
if eos_id == Nx.to_number(Nx.squeeze(next_token)) do
{:halt, current_input}
else
IO.write("#{Encoder.decode(encoder, Nx.to_number(Nx.squeeze(next_token)))}")
new_sequence = Nx.concatenate([current_input, next_token], axis: -1)
{:cont, new_sequence}
end
end)
end
end
{:module, Generator, <<70, 79, 82, 49, 0, 0, 12, ...>>, {:generate, 6}}
This generation function continuously predicts the next token and appends it to the sequence. It stops predicting if the model outputs its eos (end-of-sequence) token or if it reaches the specified max sequence length. For this example, we decode tokens at each step just to inspect what our model is doing:
Generator.generate(predict_fun, encoder, "Elixir is", blocks, 50256, 256)
Now, if you know a thing or two about Nx and JIT compilation, you’ll notice this implementation is inefficient: the input shape changes at every step, so we have to compile a new computation each time. This is meant to be a simple tutorial. Bumblebee uses more sophisticated implementations that avoid recompilation, support streaming, and a bunch of other things: https://github.com/elixir-nx/bumblebee/blob/main/lib/bumblebee/text/generation.ex
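To give a flavor of how that is usually solved, here is a rough sketch (not Bumblebee’s actual implementation) of one step of a fixed-shape loop: write the token ids into a buffer pre-padded to the maximum length, so the JIT-compiled function only ever sees a single shape, and read the logits at the index of the last real token. With a causal mask, the padding on the right has no effect on the positions we read.
# Sketch only: fixed-shape inputs avoid recompiling at every generation step.
# max_seq_len, seq_len, encoded_input, params, and predict_fun refer to the
# same values used inside Generator.generate/6 above.
buffer = Nx.broadcast(Nx.tensor(0, type: :s64), {1, max_seq_len})
buffer = Nx.put_slice(buffer, [0, 0], encoded_input)

output = predict_fun.(buffer, params)

# Read the logits at the last *real* position instead of at -1.
logits = output[[.., seq_len - 1]]
next_token = Nx.argmax(logits, axis: -1)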
Servings and Inference
Perhaps the best thing Bumblebee has to offer is its collection of pre-defined servings. Servings encapsulate machine learning tasks and can be easily added to your application’s supervision tree, then used anywhere in your application for scalable machine learning inference. Servings are really just data structures that encapsulate preprocessing, inference, and postprocessing. A serving for this model might look like:
serving =
Nx.Serving.new(fn _, _ -> &predict_fun.(&1, blocks) end)
|> Nx.Serving.client_preprocessing(fn input ->
{Nx.Batch.concatenate([Encoder.encode(encoder, input)]), :ok}
end)
|> Nx.Serving.client_postprocessing(fn {token_ids, :server_info}, _meta ->
token_ids[[.., -1]]
|> Nx.argmax(axis: -1)
|> Nx.squeeze()
|> Nx.to_number()
|> then(&Encoder.decode(encoder, &1))
end)
%Nx.Serving{
module: Nx.Serving.Default,
arg: #Function<41.3316493/2 in :erl_eval.expr/6>,
client_preprocessing: #Function<42.3316493/1 in :erl_eval.expr/6>,
client_postprocessing: #Function<41.3316493/2 in :erl_eval.expr/6>,
streaming: nil,
batch_size: nil,
distributed_postprocessing: &Function.identity/1,
process_options: [],
defn_options: []
}
Nx.Serving.run(serving, "Hello world")
"о"
Nx.Serving is a powerful abstraction that supports load balancing and distribution by default. It also supports dynamic batch inference. The process of creating a serving that handles dynamic batching of requests for you is as simple as:
Supervisor.start_link(
[
{Nx.Serving, name: GPTInference, serving: serving}
],
strategy: :one_for_one
)
{:ok, #PID<0.1036.0>}
And then you can get inferences from the named process:
Nx.Serving.batched_run(GPTInference, "Hello!")
" too"
You can use it to transform your simple Nx functions into scalable machine learning deployments embedded directly in Phoenix applications. I highly suggest reading more about it here: https://hexdocs.pm/nx/Nx.Serving.html
Conclusion
And that’s it! In this post, you implemented GPT-2 from scratch using just Nx. Hopefully, this gives you some intuition (and appreciation) of what Bumblebee is doing under the hood. Until next time!
Elixir can be the game-changer you need to put your digital product ahead of the competition. Contact us today to learn how we can put it to work for you.