checkpoint.load_config should ignore None kernel initializer #664

@fengzileee

Description

Kernel initializers can be set to, and saved as, None. In those cases, load_config should not try to look them up in networks.KERNEL_INITIALIZER.
This is probably similar to the issue addressed in 770e837.

For example, mean_kernel_init_fn could be None.

mean_kernel_init_fn: networks.Initializer | None = None,

When the config is loaded from a saved checkpoint, load_config tries to access networks.KERNEL_INITIALIZER[None], which raises a KeyError:

Traceback (most recent call last):
  File "/Users/linfeng/workspace/brax/scripts/reproduce.py", line 47, in <module>
    inference_fn = ppo_checkpoint.load_policy(
        checkpoints[0],
        network_factory=network_factory,
        deterministic=True,
    )
  File "/Users/linfeng/workspace/brax/brax/training/agents/ppo/checkpoint.py", line 83, in load_policy
    config = load_config(path)
  File "/Users/linfeng/workspace/brax/brax/training/agents/ppo/checkpoint.py", line 71, in load_config
    return checkpoint.load_config(config_path)
           ~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^
  File "/Users/linfeng/workspace/brax/brax/training/checkpoint.py", line 229, in load_config
    networks.KERNEL_INITIALIZER[init_fn_name_]
    ~~~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^
KeyError: None
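
A minimal sketch of the guard load_config could apply. The registry and the name variable here are stand-ins mirroring the traceback; the actual brax source may differ:

```python
# Stand-in for networks.KERNEL_INITIALIZER; the real registry lives in
# brax.training.networks and maps initializer names to initializer functions.
KERNEL_INITIALIZER = {"lecun_uniform": lambda: "lecun_uniform_init"}


def resolve_initializer(init_fn_name_):
    """Return the registered initializer, passing None through unchanged."""
    if init_fn_name_ is None:
        # The initializer was saved as None, so there is nothing to look up.
        return None
    return KERNEL_INITIALIZER[init_fn_name_]
```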

A reproducible example, with python==3.13.0, brax==0.14.1, and playground==0.1.0, run on macOS 15.5:

#!/usr/bin/env python3
import glob
import functools
from pathlib import Path
from brax.training.agents.ppo import train
from brax.training.agents.ppo import networks as ppo_networks
from brax.training.agents.ppo import checkpoint as ppo_checkpoint
from mujoco_playground import wrapper

from mujoco_playground.config import manipulation_params
from mujoco_playground import registry

env_name = "AeroCubeRotateZAxis"
env_cfg = registry.get_default_config(env_name)
env_cfg.episode_length = 10
env = registry.load(env_name, env_cfg)

ppo_params = manipulation_params.brax_ppo_config(env_name)
ppo_params.num_timesteps = 1
ppo_params.num_evals = 1
ppo_params.num_envs = 1
ppo_params.num_minibatches = 1
ppo_params.num_updates_per_batch = 1
ppo_params.batch_size = 1
ppo_training_params = dict(ppo_params)

network_factory = functools.partial(
    ppo_networks.make_ppo_networks,
    policy_obs_key="privileged_state",
    value_obs_key="privileged_state",
    policy_hidden_layer_sizes=(128,),
    value_hidden_layer_sizes=(128,),
)
# Override the default network factory from the params dict.
ppo_training_params["network_factory"] = network_factory

checkpoints_path = Path("./checkpoints").expanduser().resolve().as_posix()
make_inference_fn, params, metrics = train.train(
    environment=env,
    **dict(ppo_training_params),
    save_checkpoint_path=checkpoints_path,
    wrap_env_fn=wrapper.wrap_for_brax_training,
)

checkpoints = glob.glob(f"{checkpoints_path}/*/")

inference_fn = ppo_checkpoint.load_policy(
    checkpoints[0],
    network_factory=network_factory,
    deterministic=True,
)
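
Until a fix lands, one possible user-side workaround (assuming networks.KERNEL_INITIALIZER is a plain dict, as the traceback suggests) is to map the key None to the value None so the lookup passes it through. Demonstrated here with a stand-in dict rather than the real brax registry:

```python
# Stand-in dict playing the role of networks.KERNEL_INITIALIZER.
registry = {"lecun_uniform": "lecun_uniform_init"}

# registry[None] would raise KeyError: None, as in the traceback above.
# Pre-seeding the key makes the lookup return None instead:
registry.setdefault(None, None)

print(registry[None])  # lookup now succeeds and yields None
```

The real call would be networks.KERNEL_INITIALIZER.setdefault(None, None) before load_policy; this is untested against brax internals, so treat it as an illustration only.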
