You have probably heard about jax. As all the cool kids in town use it, I decided to give it a try a while ago. But it was only recently, when an experiment with TPU in pytorch ran into a fatal error due to pytorch's poor support for XLA, that I decided to port huggingface’s ELECTRA model to flax with my limited knowledge of jax/flax. This resulted in my first huggingface pull request. This blog post contains a few tips, hopefully useful, for converting a pytorch model to flax.

Why flax

You may ask why I use flax. To be honest, I don’t know. I’m a naive user just like many of you guys. Part of the reason is that flax seems to be getting some traction nowadays. In particular, huggingface has started to roll out some support for flax (with FlaxBertModel).

Personally, as a pytorch user (with a painful background in tensorflow), I find flax has a steep learning curve. I don’t really like the abstractions flax introduces; they feel like magic! However, among the many competing frameworks in the jax space (haiku, trax, objax), I had to pick one, and that’s flax. I may deeply regret my choice in the future, who knows. But c’est la vie!

Okay, let’s start.

Loading pytorch checkpoint

As a first step, you may want to download the model:

from transformers import ElectraForMaskedLM
generator = ElectraForMaskedLM.from_pretrained(model_name)

The model binary and its JSON config are cached under ~/.cache/huggingface/transformers/ with long filenames (corresponding to Amazon S3 hashes).

You can load the binary into a python dict:

import os
import torch

# open() does not expand "~", so expand it explicitly
model_file = os.path.expanduser("~/.cache/huggingface/transformers/blablabla")
with open(model_file, "rb") as state_f:
    pt_state = torch.load(state_f, map_location=torch.device("cpu"))
    # convert every tensor to a numpy array
    pt_state = {k: v.numpy() for k, v in pt_state.items()}

pt_state is a flat python dict; its first few keys look like

'electra.embeddings.word_embeddings.weight', 'electra.embeddings.position_embeddings.weight', 'electra.embeddings.token_type_embeddings.weight'

jax/flax manages model parameters as nested dicts (referred to as pytrees), so a conversion is needed.
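To make the target shape concrete, here is a minimal pure-python sketch of what such a conversion does to the dotted keys. The helper name unflatten and the toy values are made up for illustration; in practice flax.traverse_util.unflatten_dict does this job on tuple keys.

```python
def unflatten(flat_state, sep="."):
    """Turn {'a.b.c': v} into {'a': {'b': {'c': v}}}."""
    nested = {}
    for flat_key, value in flat_state.items():
        *parents, leaf = flat_key.split(sep)
        node = nested
        for name in parents:
            # walk down, creating intermediate dicts on demand
            node = node.setdefault(name, {})
        node[leaf] = value
    return nested


# toy stand-in for pt_state: strings instead of real weight arrays
toy_state = {
    "electra.embeddings.word_embeddings.weight": "W_word",
    "electra.embeddings.position_embeddings.weight": "W_pos",
}
nested_state = unflatten(toy_state)
```

Each dot in a pytorch key becomes one level of nesting, so nested_state["electra"]["embeddings"] holds one sub-dict per embedding table.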

Converting to nested dict

In a beautiful and happy world, we could do

from transformers import FlaxElectraForMaskedLM
from flax.traverse_util import flatten_dict, unflatten_dict

fx_state = FlaxElectraForMaskedLM.convert_from_pytorch(pt_state, config)
fx_state = unflatten_dict({tuple(k.split(".")): v for k, v in fx_state.items()})

Unfortunately, we don’t have FlaxElectraForMaskedLM yet, haha! So we use a similar model to load, for example FlaxBertPreTrainedModel. The key is to override convert_from_pytorch (link) so that our pytorch weights are loaded correctly in flax.

Divide and Conquer to align model weights

The whole effort lies in this part, where I basically have to check that every layer is loaded correctly in flax and that its forward pass gives the right output.

The trick is to use a scope to bind a flax module. Basically, a flax module works in two modes: bound and unbound. In bound mode, it keeps a reference to a scope, so it has access to its parameters and we can examine them. In unbound mode, a module is no different from a function: parameters are fed as arguments to the __call__ function, and the module stores nothing.

import jax.numpy as jnp
from flax.core.scope import Scope
from jax import random
from transformers.models.bert.modeling_flax_bert import FlaxBertEmbeddings

rngkey = random.PRNGKey(42)

# testing on embeddings
scope = Scope({"params": fx_state["embeddings"]}, {"params": rngkey}, mutable=["params"])

layer = FlaxBertEmbeddings(
    vocab_size=config.vocab_size,
    hidden_size=config.embedding_size,
    type_vocab_size=config.type_vocab_size,
    max_length=config.max_position_embeddings,
    parent=scope    
)

# checking param of layer_norm
layer.children["layer_norm"]

# forward pass
x_embed = layer(x)

# check if it's close to the output of the pytorch model
# (x_embed_pt is the pytorch embedding output for the same input x)
jnp.allclose(x_embed, x_embed_pt.detach().numpy())

If something goes wrong in this step, you will need to make changes to the above convert_from_pytorch.
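When outputs disagree, a bare True/False from allclose says little about how far off a layer is. A small helper along these lines prints the largest absolute difference and flags shape mismatches, which often point to a missing transpose. This is only a sketch: report_mismatch is a hypothetical name, the 1e-4 tolerance is an assumption rather than a value from the actual port, and the arrays are toy data.

```python
import numpy as np


def report_mismatch(name, fx_out, pt_out, atol=1e-4):
    """Return True if both outputs agree within atol; print the worst gap."""
    fx_out = np.asarray(fx_out)
    pt_out = np.asarray(pt_out)
    if fx_out.shape != pt_out.shape:
        print(f"{name}: shape mismatch {fx_out.shape} vs {pt_out.shape}")
        return False
    max_diff = float(np.max(np.abs(fx_out - pt_out)))
    ok = max_diff <= atol
    print(f"{name}: max abs diff = {max_diff:.2e} ({'ok' if ok else 'MISMATCH'})")
    return ok


# toy arrays standing in for the flax and pytorch outputs
reference = np.linspace(0.0, 1.0, 12, dtype=np.float32).reshape(3, 4)
close_ok = report_mismatch("embeddings", reference + 1e-6, reference)
transposed_bad = report_mismatch("dense", reference.T, reference)
```

Calling it once per layer, from the embeddings upward, narrows down the first place where the two models diverge.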

Prefer setup() over nn.compact

One tip that helps me to debug sub-modules is to use setup instead of nn.compact.

Basically, nn.compact allows us to be lazy. It’s a decorator for the forward pass so that we can declare inlined (and lazy) sub-modules. They will be lazily initialized during the forward pass (not sure why I used so many “lazy” words, I bet it has some correlation with when I wrote this: on a Friday afternoon).

setup is meant to initialize sub-modules the moment we create the module (as of now, this is not the case yet, but hopefully it will be soon). We can examine sub-modules via attribute access, like dummy.dense:

from flax import linen as nn

class Dummy(nn.Module):
    hidden_size: int
        
    def setup(self):
        self.dense = nn.Dense(self.hidden_size)
    
    def __call__(self, x):
        return self.dense(x)
    
dummy = Dummy(hidden_size=5, parent=Scope({}, {"params": rngkey}, mutable=["params"]))
dummy.dense

Common problems to keep in mind

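  • flax names the weight matrix of a dense layer kernel instead of weight: make sure to rename accordingly
  • sometimes, you have to transpose the weight (a pytorch Linear stores it as (out_features, in_features), while a flax Dense kernel is (in_features, out_features))
  • sometimes, you have to add a missing sub-module, such as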
class FlaxElectraGeneratorPredictions(nn.Module):
    embedding_size: int
    hidden_act: str = "gelu"
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        self.dense = nn.Dense(self.embedding_size, dtype=self.dtype)
        self.layer_norm = FlaxElectraLayerNorm(dtype=self.dtype)

    def __call__(self, hidden_states):
        hidden_states = self.dense(hidden_states)
        hidden_states = ACT2FN[self.hidden_act](hidden_states)
        hidden_states = self.layer_norm(hidden_states)
        return hidden_states
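The first two items in the list above can be sketched with plain numpy. This is a toy illustration under stated assumptions: linear_to_dense is a hypothetical helper, the layer sizes are arbitrary, and it is not the conversion code from the actual port.

```python
import numpy as np


def linear_to_dense(pt_layer_state):
    """Map a pytorch Linear-style {'weight', 'bias'} dict to flax's {'kernel', 'bias'}."""
    return {
        # rename weight -> kernel and transpose (out, in) -> (in, out)
        "kernel": np.transpose(pt_layer_state["weight"]),
        "bias": pt_layer_state["bias"],
    }


rng = np.random.default_rng(0)
pt_dense = {
    "weight": rng.normal(size=(5, 3)).astype(np.float32),  # out_features=5, in_features=3
    "bias": np.zeros(5, dtype=np.float32),
}
fx_dense = linear_to_dense(pt_dense)

x = rng.normal(size=(2, 3)).astype(np.float32)
pt_out = x @ pt_dense["weight"].T + pt_dense["bias"]  # what a pytorch Linear computes
fx_out = x @ fx_dense["kernel"] + fx_dense["bias"]    # what a flax Dense computes
```

Both expressions give the same result, which is the property to check after renaming and transposing a real layer.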

Check out my code if you want to know more.

Conclusion

The process of porting a model to flax is time-consuming. I hope this post can ease some of the pain in the process. I have decided to give flax a serious try, and I may post something about flax or jax in the future.

I appreciate any feedback or questions; feel free to reach out.