# DALLÂ·E mini - Embedding Retrain Preparation

We'll start with the dalle-mini model for faster experimentation.

In [1]:
DALLE_MODEL = "dalle-mini/dalle-mini/mini-1:v0"  # can be wandb artifact or ðŸ¤— Hub or local folder or google bucket
DALLE_COMMIT_ID = None

# # dalle-mega
# DALLE_MODEL = "dalle-mini/dalle-mini/mega-1-fp16:latest"  # can be wandb artifact or ðŸ¤— Hub or local folder or google bucket
# DALLE_COMMIT_ID = None

In [2]:
import jax
import jax.numpy as jnp

# check how many devices are available
jax.local_device_count()

8

## Load model

We load the model twice to keep a copy of the original parameters.

In [3]:
# Load model
from dalle_mini import DalleBart, DalleBartProcessor

model, params = DalleBart.from_pretrained(
    DALLE_MODEL, revision=DALLE_COMMIT_ID, dtype=jnp.float16, _do_init=False
)

_, params_original = DalleBart.from_pretrained(
    DALLE_MODEL, revision=DALLE_COMMIT_ID, dtype=jnp.float16, _do_init=False
)

[34m[1mwandb[0m: Downloading large artifact mini-1:v0, 1673.43MB. 7 files... Done. 0:0:1.2
tcmalloc: large alloc 1751343104 bytes == 0x56011a2c0000 @  0x7f143aaa9680 0x7f143aaca824 0x5600d248253b 0x5600d24c30ba 0x5600d2599a58 0x5600d24f548d 0x5600d23cf328 0x5600d25af66d 0x5600d24f5825 0x5600d24532da 0x5600d24eafe3 0x5600d24ec709 0x5600d249a1ea 0x5600d252be7a 0x5600d24eafe3 0x5600d24ec709 0x5600d245273d 0x5600d24eafe3 0x5600d2597a7c 0x5600d24ebdbb 0x5600d25ce33e 0x5600d24f5571 0x5600d2452088 0x5600d24e07cb 0x5600d252f0fc 0x5600d24e07cb 0x5600d252f0fc 0x5600d24e07cb 0x5600d24f5f94 0x5600d24532da 0x5600d24ebbe4
[34m[1mwandb[0m: Downloading large artifact mini-1:v0, 1673.43MB. 7 files... Done. 0:0:1.2
tcmalloc: large alloc 1751343104 bytes == 0x56011a2c0000 @  0x7f143aaa9680 0x7f143aaca824 0x5600d248253b 0x5600d24c30ba 0x5600d2599a58 0x5600d24f548d 0x5600d23cf328 0x5600d25af66d 0x5600d24f5825 0x5600d24532da 0x5600d24eafe3 0x5600d24ec709 0x5600d249a1ea 0x5600d252be7a 0x5600d24eafe3 0x

## Model surgery: remove layers to be retrained

Let's take a look at the params tree.

In [4]:
sum(x.size for x in jax.tree_leaves(params))

437833712

In [5]:
import json
tree = jax.tree_map(lambda x: x.shape, params)
print(json.dumps(tree, indent=2))

{
  "lm_head": {
    "kernel": [
      1024,
      16385
    ]
  },
  "model": {
    "decoder": {
      "embed_positions": {
        "embedding": [
          256,
          1024
        ]
      },
      "embed_tokens": {
        "embedding": [
          16385,
          1024
        ]
      },
      "final_ln": {
        "bias": [
          1024
        ]
      },
      "layernorm_embedding": {
        "bias": [
          1024
        ],
        "scale": [
          1024
        ]
      },
      "layers": {
        "FlaxBartDecoderLayers": {
          "FlaxBartAttention_0": {
            "k_proj": {
              "kernel": [
                12,
                1024,
                1024
              ]
            },
            "out_proj": {
              "kernel": [
                12,
                1024,
                1024
              ]
            },
            "q_proj": {
              "kernel": [
                12,
                1024,
                1024
              

We will remove or reinitialize:
- `lm_head`
- `model.decoder.embed_positions`
- `model.decoder.embed_tokens`
- `model.decoder.final_ln`
- `model.decoder.layernorm_embedding`

In [6]:
del params['lm_head']
for layer in ['embed_positions', 'embed_tokens', 'final_ln', 'layernorm_embedding']:
    del params['model']['decoder'][layer]

In [7]:
jax.tree_map(lambda x: x.shape, params)

{'model': {'decoder': {'layers': {'FlaxBartDecoderLayers': {'FlaxBartAttention_0': {'k_proj': {'kernel': (12,
        1024,
        1024)},
      'out_proj': {'kernel': (12, 1024, 1024)},
      'q_proj': {'kernel': (12, 1024, 1024)},
      'v_proj': {'kernel': (12, 1024, 1024)}},
     'FlaxBartAttention_1': {'k_proj': {'kernel': (12, 1024, 1024)},
      'out_proj': {'kernel': (12, 1024, 1024)},
      'q_proj': {'kernel': (12, 1024, 1024)},
      'v_proj': {'kernel': (12, 1024, 1024)}},
     'GLU_0': {'Dense_0': {'kernel': (12, 1024, 2730)},
      'Dense_1': {'kernel': (12, 1024, 2730)},
      'Dense_2': {'kernel': (12, 2730, 1024)},
      'LayerNorm_0': {'bias': (12, 1024)},
      'LayerNorm_1': {'bias': (12, 2730)}},
     'LayerNorm_0': {'bias': (12, 1024)},
     'LayerNorm_1': {'bias': (12, 1024), 'scale': (12, 1024)},
     'LayerNorm_2': {'bias': (12, 1024)},
     'LayerNorm_3': {'bias': (12, 1024), 'scale': (12, 1024)}}}},
  'encoder': {'embed_positions': {'embedding': (64, 1024)},

In [8]:
sum(x.size for x in jax.tree_leaves(params))

404012016

## Reinitialize layers

We save a checkpoint and reload it again. It does not automatically reinitialize the missing keys, but it sets `_missing_keys` appropriately so we can initialize them later. We could do the same by simply setting that property ourselves, but I'll refrain from doing so because it's a private implementation detail.

In [9]:
trimmed_checkpoint = "mini-trimmed"

In [10]:
model.save_pretrained(trimmed_checkpoint, params=params)

tcmalloc: large alloc 1610424320 bytes == 0x5632d11c8000 @  0x7f95ccad0680 0x7f95ccaf0bdd 0x7f95be99e29f 0x7f95be9a7750 0x7f95be9a87b4 0x7f95be9a87b4 0x7f95be9a87b4 0x7f95be9a87b4 0x7f95be9a87b4 0x7f95be9a87b4 0x7f95be9a87b4 0x7f95be9a4fc4 0x7f95be9a571e 0x5630fb630f94 0x5630fb58e2da 0x5630fb625fe3 0x5630fb626d24 0x5630fb58d73d 0x5630fb625fe3 0x5630fb626d24 0x5630fb58d73d 0x5630fb626be4 0x5630fb58d088 0x5630fb625fe3 0x5630fb627709 0x5630fb58d73d 0x5630fb625fe3 0x5630fb6d2a7c 0x5630fb626dbb 0x5630fb70933e 0x5630fb630571
tcmalloc: large alloc 3231449088 bytes == 0x56333119a000 @  0x7f95ccad0680 0x7f95ccaf0bdd 0x7f95be99e29f 0x7f95be9a7750 0x7f95be9a87b4 0x7f95be9a87b4 0x7f95be9a87b4 0x7f95be9a87b4 0x7f95be9a87b4 0x7f95be9a87b4 0x7f95be9a87b4 0x7f95be9a4fc4 0x7f95be9a571e 0x5630fb630f94 0x5630fb58e2da 0x5630fb625fe3 0x5630fb626d24 0x5630fb58d73d 0x5630fb625fe3 0x5630fb626d24 0x5630fb58d73d 0x5630fb626be4 0x5630fb58d088 0x5630fb625fe3 0x5630fb627709 0x5630fb58d73d 0x5630fb625fe3 0x5630fb6d

In [11]:
model, params = DalleBart.from_pretrained(
    trimmed_checkpoint, revision=None, dtype=jnp.float16, _do_init=False
)

The checkpoint mini-trimmed is missing required keys: {('model', 'decoder', 'embed_tokens', 'embedding'), ('lm_head', 'kernel'), ('model', 'decoder', 'embed_positions', 'embedding'), ('model', 'decoder', 'final_ln', 'bias'), ('model', 'decoder', 'layernorm_embedding', 'scale'), ('model', 'decoder', 'layernorm_embedding', 'bias')}. Make sure to call model.init_weights to initialize the missing weights.
Some weights of DalleBart were not initialized from the model checkpoint at mini-trimmed and are newly initialized: {('model', 'decoder', 'embed_tokens', 'embedding'), ('lm_head', 'kernel'), ('model', 'decoder', 'embed_positions', 'embedding'), ('model', 'decoder', 'final_ln', 'bias'), ('model', 'decoder', 'layernorm_embedding', 'scale'), ('model', 'decoder', 'layernorm_embedding', 'bias')}
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [12]:
model._missing_keys

{('lm_head', 'kernel'),
 ('model', 'decoder', 'embed_positions', 'embedding'),
 ('model', 'decoder', 'embed_tokens', 'embedding'),
 ('model', 'decoder', 'final_ln', 'bias'),
 ('model', 'decoder', 'layernorm_embedding', 'bias'),
 ('model', 'decoder', 'layernorm_embedding', 'scale')}

In [13]:
params_reinit = model.init_weights(model.key, model.input_shape, params=params)

### Verification

The structure should be the same as the original `params` dict. Re-initialized layers should have different parameters, but existing layers should be the same.

In [14]:
jax.tree_map(lambda x: x.shape, params_reinit)

FrozenDict({
    lm_head: {
        kernel: (1024, 16385),
    },
    model: {
        decoder: {
            embed_positions: {
                embedding: (256, 1024),
            },
            embed_tokens: {
                embedding: (16385, 1024),
            },
            final_ln: {
                bias: (1024,),
            },
            layernorm_embedding: {
                bias: (1024,),
                scale: (1024,),
            },
            layers: {
                FlaxBartDecoderLayers: {
                    FlaxBartAttention_0: {
                        k_proj: {
                            kernel: (12, 1024, 1024),
                        },
                        out_proj: {
                            kernel: (12, 1024, 1024),
                        },
                        q_proj: {
                            kernel: (12, 1024, 1024),
                        },
                        v_proj: {
                            kernel: (12, 1024, 1024),
       

In [15]:
params_reinit['model']['decoder']['embed_positions']

FrozenDict({
    embedding: DeviceArray([[ 0.00582082, -0.04113895,  0.00918633, ..., -0.00530822,
                   0.01297319,  0.02720674],
                 [ 0.03540739,  0.03676804, -0.02924041, ...,  0.00163185,
                  -0.01938273, -0.02105987],
                 [ 0.00478452, -0.03438002, -0.0024974 , ..., -0.03892584,
                   0.01721252,  0.02605445],
                 ...,
                 [ 0.02495495,  0.00559381, -0.01588043, ...,  0.01393714,
                  -0.01824111, -0.02007291],
                 [ 0.00983252, -0.00180564, -0.01686333, ..., -0.01001718,
                   0.01886345, -0.00393983],
                 [-0.03589988, -0.00455565,  0.00076276, ..., -0.02145007,
                  -0.00180798, -0.0133148 ]], dtype=float32),
})

In [16]:
embedding_new = params_reinit['model']['decoder']['embed_positions']['embedding']
embedding_new.min(), embedding_new.max()

(DeviceArray(-0.09320386, dtype=float32),
 DeviceArray(0.08769083, dtype=float32))

In [17]:
params_original['model']['decoder']['embed_positions']

{'embedding': DeviceArray([[ 0.03459017, -0.0065838 , -0.11748601, ..., -0.01451578,
               -0.03927238, -0.00266367],
              [-0.03116009,  0.00438436,  0.02691377, ..., -0.02886203,
               -0.01095741, -0.02649871],
              [-0.03568491, -0.0086962 ,  0.01851564, ..., -0.04736514,
                0.05310551, -0.01648099],
              ...,
              [-0.02454913,  0.03746822, -0.02269235, ...,  0.03377315,
                0.003004  ,  0.04975331],
              [-0.05145862,  0.04472217,  0.11103845, ...,  0.04581303,
                0.02850476,  0.00554514],
              [-0.01037806,  0.00281054, -0.0485299 , ..., -0.03325456,
               -0.0058979 ,  0.01733843]], dtype=float32)}

In [18]:
embedding_original = params_original['model']['decoder']['embed_positions']['embedding']
embedding_original.min(), embedding_new.max()

(DeviceArray(-0.25866088, dtype=float32),
 DeviceArray(0.08769083, dtype=float32))

In [19]:
assert(
    jnp.allclose(embedding_new, embedding_original).item() == False
)

In [20]:
lm_head_original = params_original['lm_head']['kernel']
lm_head_reinit = params_reinit['lm_head']['kernel']
assert(
    jnp.allclose(lm_head_reinit, lm_head_original).item() == False
)

In [21]:
assert(
    jnp.allclose(
        params_reinit['model']['encoder']['layers']['FlaxBartEncoderLayers']['FlaxBartAttention_0']['k_proj']['kernel'],
        params_original['model']['encoder']['layers']['FlaxBartEncoderLayers']['FlaxBartAttention_0']['k_proj']['kernel']
    ).item()
)

## Save checkpoint for retrain

Finally, we save the resulting model to retrain those layers.

In [22]:
checkpoint_dir = "mini-reinit"

In [23]:
model.save_pretrained(checkpoint_dir, params=params_reinit)

tcmalloc: large alloc 3367796736 bytes == 0x5633f235a000 @  0x7f95ccad0680 0x7f95ccaf0bdd 0x7f95be99e29f 0x7f95be9a7750 0x7f95be9a87b4 0x7f95be9a87b4 0x7f95be9a87b4 0x7f95be9a87b4 0x7f95be9a87b4 0x7f95be9a87b4 0x7f95be9a87b4 0x7f95be9a4fc4 0x7f95be9a571e 0x5630fb630f94 0x5630fb58e2da 0x5630fb625fe3 0x5630fb626d24 0x5630fb58d73d 0x5630fb625fe3 0x5630fb626d24 0x5630fb58d73d 0x5630fb626be4 0x5630fb58d088 0x5630fb625fe3 0x5630fb627709 0x5630fb58d73d 0x5630fb625fe3 0x5630fb6d2a7c 0x5630fb626dbb 0x5630fb70933e 0x5630fb630571


### Upload checkpoint to W&B

In [24]:
import wandb
from pathlib import Path

In [25]:
wandb.init(
    entity = 'dalle-mini',
    project = 'dalle-mini',
    job_type = 'Seq2Seq',
)

[34m[1mwandb[0m: Currently logged in as: [33mpcuenq[0m ([33mdalle-mini[0m). Use [1m`wandb login --relogin`[0m to force relogin


In [26]:
artifact = wandb.Artifact(
    name=f"model-{wandb.run.id}",
    type="DalleBart_model",
    metadata={"embeddings": "reset"},
)

for filename in ["config.json", "flax_model.msgpack"]:
    artifact.add_file(f"{Path(checkpoint_dir) / filename}")

In [27]:
wandb.run.log_artifact(artifact)

<wandb.sdk.wandb_artifacts.Artifact at 0x7f95984c3fd0>

In [28]:
wandb.finish()

VBox(children=(Label(value='1670.207 MB of 1670.207 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.â€¦

----