Kernel initializers can be set to `None` and saved that way. In those cases, we should not look them up in `networks.KERNEL_INITIALIZER`. This is probably similar to the issue addressed in 770e837.
For example, `mean_kernel_init_fn` can be `None` (brax/brax/training/agents/ppo/networks.py, line 97 at 4597b9c):

```python
mean_kernel_init_fn: networks.Initializer | None = None,
```
When the policy is loaded from the saved checkpoint, we end up accessing `networks.KERNEL_INITIALIZER[None]`, resulting in an error:
```
Traceback (most recent call last):
  File "/Users/linfeng/workspace/brax/scripts/reproduce.py", line 47, in <module>
    inference_fn = ppo_checkpoint.load_policy(
        checkpoints[0],
        network_factory=network_factory,
        deterministic=True,
    )
  File "/Users/linfeng/workspace/brax/brax/training/agents/ppo/checkpoint.py", line 83, in load_policy
    config = load_config(path)
  File "/Users/linfeng/workspace/brax/brax/training/agents/ppo/checkpoint.py", line 71, in load_config
    return checkpoint.load_config(config_path)
           ~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^
  File "/Users/linfeng/workspace/brax/brax/training/checkpoint.py", line 229, in load_config
    networks.KERNEL_INITIALIZER[init_fn_name_]
    ~~~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^
KeyError: None
```
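The root cause can be sketched with a plain dict, assuming `KERNEL_INITIALIZER` is a name-to-initializer registry as the traceback suggests. The registry contents and the helper name below are hypothetical, not brax's actual code; the point is only that a saved `None` should be passed through rather than looked up.

```python
# Hypothetical registry standing in for networks.KERNEL_INITIALIZER.
KERNEL_INITIALIZER = {"lecun_uniform": "jax.nn.initializers.lecun_uniform"}


def resolve_init_fn(name):
    """Look up a saved initializer name, passing None through unchanged."""
    if name is None:
        # A saved config may legitimately store None (e.g. mean_kernel_init_fn).
        return None
    return KERNEL_INITIALIZER[name]


# The unguarded lookup reproduces the bug:
# KERNEL_INITIALIZER[None]  -> KeyError: None
```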
A reproducible example, with `python==3.13.0`, `brax==0.14.1`, and `playground==0.1.0`, run on macOS 15.5:
```python
#!/usr/bin/env python3
import functools
import glob
from pathlib import Path

from brax.training.agents.ppo import checkpoint as ppo_checkpoint
from brax.training.agents.ppo import networks as ppo_networks
from brax.training.agents.ppo import train
from mujoco_playground import registry, wrapper
from mujoco_playground.config import manipulation_params

# Keep the run as small as possible; we only need a checkpoint on disk.
env_name = "AeroCubeRotateZAxis"
env_cfg = registry.get_default_config(env_name)
env_cfg.episode_length = 10
env = registry.load(env_name, env_cfg)

ppo_params = manipulation_params.brax_ppo_config(env_name)
ppo_params.num_timesteps = 1
ppo_params.num_evals = 1
ppo_params.num_envs = 1
ppo_params.num_minibatches = 1
ppo_params.num_updates_per_batch = 1
ppo_params.batch_size = 1
ppo_training_params = dict(ppo_params)

network_factory = functools.partial(
    ppo_networks.make_ppo_networks,
    policy_obs_key="privileged_state",
    value_obs_key="privileged_state",
    policy_hidden_layer_sizes=(128,),
    value_hidden_layer_sizes=(128,),
)
ppo_training_params["network_factory"] = network_factory

checkpoints_path = Path("./checkpoints").expanduser().resolve().as_posix()
make_inference_fn, params, metrics = train.train(
    environment=env,
    **ppo_training_params,
    save_checkpoint_path=checkpoints_path,
    wrap_env_fn=wrapper.wrap_for_brax_training,
)

# Loading the checkpoint we just saved raises KeyError: None.
checkpoints = glob.glob(f"{checkpoints_path}/*/")
inference_fn = ppo_checkpoint.load_policy(
    checkpoints[0],
    network_factory=network_factory,
    deterministic=True,
)
```
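A symmetric guard on both the save and load paths would fix this. The sketch below is self-contained and hedged: the helper names are hypothetical and only illustrate that `None` should round-trip without ever touching the registry.

```python
# Hypothetical save/load helpers; not brax's actual API. The registry maps
# names to initializer objects, and None must round-trip without a lookup.
KERNEL_INITIALIZER = {"lecun_uniform": object()}


def init_fn_to_name(init_fn):
    """Reverse lookup used when saving a config; None stays None."""
    if init_fn is None:
        return None
    for name, fn in KERNEL_INITIALIZER.items():
        if fn is init_fn:
            return name
    raise ValueError("unregistered initializer")


def name_to_init_fn(name):
    """Forward lookup used when loading a config; None stays None."""
    return None if name is None else KERNEL_INITIALIZER[name]
```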