Tips on converting pytorch models to flax
- Why flax
- Loading pytorch checkpoint
- Converting to nested dict
- Divide and Conquer to align model weights
- Prefer setup() instead of nn.compact
- Common problems to keep in mind
- Conclusion
You have probably heard about jax. Since all the cool kids in town use it, I decided to give it a try a while ago, but I didn't commit until recently, when an experiment with TPUs in pytorch hit a fatal error due to pytorch's poor XLA support. With my limited knowledge of jax/flax, I decided to port huggingface's ELECTRA model to flax. This resulted in my first huggingface pull request. This blog post contains a few tips, hopefully useful, for converting a pytorch model to flax.
Why flax
You may ask why I use flax. To be honest, I don't know. I'm a naive user just like many of you. Part of the reason is that flax seems to be getting some traction nowadays. In particular, huggingface has started to roll out some support for flax (with FlaxBertModel).
Personally, as a pytorch user (with a painful background in tensorflow), I find flax has a steep learning curve. I don't really like the abstractions flax introduces; it feels like magic! However, among the many competing frameworks in the jax space (haiku, trax, objax), I had to pick one, and that's flax. I may deeply regret my choice in the future, who knows. But c'est la vie!
Okay, let’s start.
Loading pytorch checkpoint
First, you may want to download the model:
from transformers import ElectraForMaskedLM

# model_name is an ELECTRA checkpoint name, e.g. "google/electra-small-generator"
generator = ElectraForMaskedLM.from_pretrained(model_name)
The model binary and its JSON config are cached under ~/.cache/huggingface/transformers/ with long filenames (corresponding to Amazon S3 hashes).
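If you want the exact file path so you can load it by hand, one quick and dirty way (just a convenience sketch, assuming the default cache location) is to grab the largest file in the cache directory, since the model binary dwarfs the config and vocab files:
from pathlib import Path

# the largest cached file is (almost always) the model binary itself
cache_dir = Path.home() / ".cache" / "huggingface" / "transformers"
model_file = max((p for p in cache_dir.glob("*") if p.is_file()), key=lambda p: p.stat().st_size)
print(model_file)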
You can load the binary into a python dict:
import os
import torch

model_file = "~/.cache/huggingface/transformers/blablabla"
with open(os.path.expanduser(model_file), "rb") as state_f:
    pt_state = torch.load(state_f, map_location=torch.device("cpu"))
pt_state = {k: v.numpy() for k, v in pt_state.items()}
pt_state is a flat python dict; the first few keys look like
'electra.embeddings.word_embeddings.weight', 'electra.embeddings.position_embeddings.weight', 'electra.embeddings.token_type_embeddings.weight'
jax/flax uses nested dicts to manage model parameters (referred to as pytrees), so a conversion is needed.
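To get a feel for what that looks like, here is a toy round-trip with flax's flatten_dict/unflatten_dict (the keys below are made up, not the real ELECTRA ones):
from flax.traverse_util import flatten_dict, unflatten_dict

flat = {"embeddings.word_embeddings.weight": 1, "encoder.layer_0.dense.weight": 2}
nested = unflatten_dict({tuple(k.split(".")): v for k, v in flat.items()})
# {'embeddings': {'word_embeddings': {'weight': 1}},
#  'encoder': {'layer_0': {'dense': {'weight': 2}}}}
assert flatten_dict(nested) == {tuple(k.split(".")): v for k, v in flat.items()}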
Converting to nested dict
In a beautiful and happy world, we could do
from transformers import FlaxElectraForMaskedLM
from flax.traverse_util import flatten_dict, unflatten_dict
fx_state = FlaxElectraForMaskedLM.convert_from_pytorch(pt_state, config)
fx_state = unflatten_dict({tuple(k.split(".")): v for k, v in fx_state.items()})
Unfortunately, we don't have FlaxElectraForMaskedLM yet, haha! So we use a similar model to load it, for example FlaxBertPreTrainedModel. The key is to override convert_from_pytorch (link) so that our pytorch weights are loaded correctly in flax.
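To give an idea of what such an override looks like, here is a heavily simplified sketch; the patterns in it (the "electra." prefix, the dense rename and transpose) are only illustrative, and the real override handles many more layers:
def convert_from_pytorch(pt_state: dict, config) -> dict:
    # simplified sketch: rename/reshape pytorch tensors so that keys and shapes
    # match what the flax model expects
    fx_state = {}
    for key, tensor in pt_state.items():
        key = key.replace("electra.", "")  # illustrative: drop the pytorch prefix
        if key.endswith("dense.weight"):
            tensor = tensor.T  # pytorch Linear stores (out, in); flax Dense kernel is (in, out)
            key = key.replace("weight", "kernel")  # flax calls the matrix "kernel"
        fx_state[key] = tensor
    return fx_state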
Divide and Conquer to align model weights
The whole effort lies in this part, where I basically have to check that every layer is loaded correctly in flax and that the forward pass produces the right output.
The trick is to use a scope to bind the flax module. Basically, a flax module works in two modes: bound and unbound. In bound mode, it keeps a reference to a scope, so it has access to its parameters and we can examine them. In unbound mode, the module is no different from a function: parameters are fed as arguments to the __call__ function, and the module stores nothing.
import jax.numpy as jnp
from flax.core.scope import Scope
from jax import random

rngkey = random.PRNGKey(42)

# testing on embeddings: bind the converted params to the module via a scope
scope = Scope({"params": fx_state["embeddings"]}, {"params": rngkey}, mutable=["params"])
layer = FlaxBertEmbeddings(
    vocab_size=config.vocab_size,
    hidden_size=config.embedding_size,
    type_vocab_size=config.type_vocab_size,
    max_length=config.max_position_embeddings,
    parent=scope,
)
# checking params of layer_norm
layer.children["layer_norm"]
# forward pass (x holds token ids; x_embed_pt is the pytorch embedding output)
x_embed = layer(x)
# check that it's close to the output of the pytorch model
jnp.allclose(x_embed, x_embed_pt.numpy())
If something goes wrong in this step, you will need to go back and adjust convert_from_pytorch above.
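When outputs don't match, what usually helps me is dumping key paths and shapes on both sides and eyeballing the differences, something like this sketch:
from flax.traverse_util import flatten_dict

# flax side: key paths and shapes of the converted sub-tree
for path, value in sorted(flatten_dict(fx_state["embeddings"]).items()):
    print("/".join(path), value.shape)

# pytorch side: the corresponding original keys
for key, value in pt_state.items():
    if key.startswith("electra.embeddings."):
        print(key, value.shape)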
Prefer setup() instead of nn.compact
One tip that helps me debug sub-modules is to use setup instead of nn.compact.
Basically, nn.compact allows us to be lazy. It's a decorator for the forward pass that lets us declare inlined (and lazy) sub-modules, which are then lazily initialized during the forward pass (not sure why I used so many "lazy" words; I bet it has some correlation with when I wrote this: on a Friday afternoon).
setup, on the other hand, initializes sub-modules the moment we create the module (as of now, this hasn't happened yet, but hopefully soon). We can then examine sub-modules via attribute access like dummy.dense:
from flax import linen as nn

class Dummy(nn.Module):
    hidden_size: int

    def setup(self):
        self.dense = nn.Dense(self.hidden_size)

    def __call__(self, x):
        return self.dense(x)

dummy = Dummy(hidden_size=5, parent=Scope({}, {"params": rngkey}, mutable=["params"]))
dummy.dense
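For contrast, here is the same toy module written with nn.compact; the sub-module is declared inline, so there is no self.dense attribute to inspect before the forward pass runs:
class DummyCompact(nn.Module):
    hidden_size: int

    @nn.compact
    def __call__(self, x):
        # the Dense layer only comes into existence when __call__ runs
        return nn.Dense(self.hidden_size)(x)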
Common problems to keep in mind
- flax uses kernel instead of weight for parameters: make sure to rename accordingly
- sometimes, you have to transpose the weight (see the small shape illustration after the example below)
- sometimes, you have to add a missing sub-module, such as
class FlaxElectraGeneratorPredictions(nn.Module):
    embedding_size: int
    hidden_act: str = "gelu"
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        self.dense = nn.Dense(self.embedding_size, dtype=self.dtype)
        self.layer_norm = FlaxElectraLayerNorm(dtype=self.dtype)

    def __call__(self, hidden_states):
        hidden_states = self.dense(hidden_states)
        hidden_states = ACT2FN[self.hidden_act](hidden_states)
        hidden_states = self.layer_norm(hidden_states)
        return hidden_states
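And here is the small shape illustration promised above for the rename/transpose bullets, using a hypothetical layer with 4 inputs and 3 outputs:
import numpy as np

# pytorch nn.Linear(in_features=4, out_features=3) stores its weight as (out, in)
pt_weight = np.zeros((3, 4))

# flax nn.Dense(features=3) stores the same matrix as "kernel" with shape (in, out),
# so converting is a rename plus a transpose
fx_kernel = pt_weight.T
assert fx_kernel.shape == (4, 3)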
Check out my code if you want to know more.
Conclusion
The process of porting a model to flax is time-consuming. I hope this post can alleviate some of the pain in the process. I have decided to give flax a serious try, so I may post more about flax or jax in the future.
I appreciate any feedback or questions; feel free to reach out.