Due to changes in the JAX tree API, `TreeClass` is no longer treated as a named tuple when indexing with `AtIndexer`/`.at`. Fields must now be selected by name, not by position.
import pytreeclass as tc

class Tree(tc.TreeClass):
    def __init__(self, a, b):
        self.a = a
        self.b = b

tree = Tree(1, 2)
print(tree.at["a"].get())  # 1
print(tree.at[0].get())  # previously 1, no longer valid
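- Support multiple backends, selected with the `PYTREECLASS_BACKEND` environment variable:
  - `numpy` + `optree` via `export PYTREECLASS_BACKEND=numpy` (lightweight option)
  - `jax` via `export PYTREECLASS_BACKEND=jax` (the default)
  - `torch` + `optree` via `export PYTREECLASS_BACKEND=torch`
  - no array library, `optree` only, via `export PYTREECLASS_BACKEND=default`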
- Drop the `callback` option from the parallel options in `is_parallel`.
- Add parallel processing via `is_parallel` to `.{get,set}`.
- Add `register_excluded_type` to `autoinit` to exclude certain types from being used as `field` defaults.
- Rename `threads_count` in the `apply` parallel kwargs to `max_workers`.
- Add a parallel mapping option to `AtIndexer`. This enables many tasks, such as reading a pytree of image file names.
# benchmarking parallel vs. sequential image reads (%timeit requires IPython/Jupyter)
# on a Mac M1 CPU with an image of size 512x512x3
import pytreeclass as tc
from matplotlib.pyplot import imread

paths = ["lenna.png"] * 10
indexer = tc.AtIndexer(paths)

%timeit indexer[...].apply(imread, is_parallel=True)  # parallel
# 24.9 ms ± 938 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
%timeit indexer[...].apply(imread)  # not parallel
# 84.8 ms ± 453 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
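A rough stdlib sketch of what `is_parallel` amounts to for an I/O-bound leaf function such as `imread`: leaves are mapped through a thread pool and the results come back in their original order. The helper `apply_leaves` is hypothetical and only mirrors the options described above (a boolean, or a dict of executor keyword arguments such as `max_workers`); it is not pytreeclass's implementation.

```python
from concurrent.futures import ThreadPoolExecutor
from typing import Any, Callable, Dict, List, Sequence, Union


def apply_leaves(
    func: Callable[[Any], Any],
    leaves: Sequence[Any],
    is_parallel: Union[bool, Dict[str, Any]] = False,
) -> List[Any]:
    """Hypothetical helper: apply `func` to each leaf, optionally in parallel."""
    if not is_parallel:
        # serial path: a plain list comprehension
        return [func(leaf) for leaf in leaves]
    # `True` means default executor options; a dict is forwarded as
    # keyword arguments (e.g. {"max_workers": 4}) to the thread pool
    options = is_parallel if isinstance(is_parallel, dict) else {}
    with ThreadPoolExecutor(**options) as executor:
        # Executor.map yields results in input order, not completion order
        return list(executor.map(func, leaves))


paths = ["a.png", "bb.png", "ccc.png"]
print(apply_leaves(len, paths, is_parallel={"max_workers": 2}))  # [5, 6, 7]
```

Threads help here because file reads spend most of their time waiting on I/O; for CPU-bound leaf functions a process pool would be the usual choice.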
- To simplify the API, the following will be removed: `tree_repr_with_trace`, `tree_map_with_trace`, `tree_flatten_with_trace`, and `tree_leaves_with_trace`. Variants of these will be included in the common recipes.
- Add `on_getattr` to `field` to apply functions on `__getattr__`.
- Rename `callbacks` in `field` to `on_setattr` to match `attrs` and better reflect its functionality.
These changes enable:
- Stricter data validation on instance values, as in the following example. `on_setattr` ensures the value is of a certain type (e.g. integer) during initialization, and `on_getattr` ensures the value is of a certain type (e.g. integer) whenever it is accessed.

import pytreeclass as tc
import jax

def assert_int(x):
    assert isinstance(x, int), "must be an int"
    return x

@tc.autoinit
class Tree(tc.TreeClass):
    a: int = tc.field(on_getattr=[assert_int], on_setattr=[assert_int])

    def __call__(self, x):
        # ensure `a` is an int before using it in computation by calling `assert_int`
        a: int = self.a
        return a + x

tree = Tree(a=1)
print(tree(1.0))  # 2.0
tree = jax.tree_util.tree_map(lambda x: x + 0.0, tree)  # make `a` a float
tree(1.0)  # AssertionError: must be an int
- Frozen fields without using `tree_mask`/`tree_unmask`. The following pattern freezes the value on `__setattr__` and unfreezes it whenever it is accessed, which ensures that `jax` transformations do not see the value:

import pytreeclass as tc
import jax

@tc.autoinit
class Tree(tc.TreeClass):
    frozen_a: int = tc.field(on_getattr=[tc.unfreeze], on_setattr=[tc.freeze])

    def __call__(self, x):
        return self.frozen_a + x

tree = Tree(frozen_a=1)  # 1 is a non-jax type
# can be used in jax transformations

@jax.jit
def f(tree, x):
    return tree(x)

f(tree, 1.0)  # 2.0
grads = jax.grad(f)(tree, 1.0)  # Tree(frozen_a=#1)

Compared with other libraries that implement `static_field`, this pattern has lower overhead and does not alter the `tree_flatten`/`tree_unflatten` methods of the tree.
- An easier way to create a buffer (non-trainable array): just use `jax.lax.stop_gradient` in `on_getattr`.

import pytreeclass as tc
import jax
import jax.numpy as jnp

def assert_array(x):
    assert isinstance(x, jax.Array)
    return x

@tc.autoinit
class Tree(tc.TreeClass):
    buffer: jax.Array = tc.field(on_getattr=[jax.lax.stop_gradient], on_setattr=[assert_array])

    def __call__(self, x):
        return self.buffer**x

tree = Tree(buffer=jnp.array([1.0, 2.0, 3.0]))
tree(2.0)  # Array([1., 4., 9.], dtype=float32)

@jax.jit
def f(tree, x):
    return jnp.sum(tree(x))

f(tree, 1.0)  # Array(6., dtype=float32)
print(jax.grad(f)(tree, 1.0))  # Tree(buffer=[0. 0. 0.])
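The `on_setattr`/`on_getattr` contract can be approximated in plain Python with a data descriptor: setter callbacks run before a value is stored, and getter callbacks run on every read. `Field` below is a hypothetical stand-in for `tc.field`, not its implementation; writing to `vars(tree)` directly stands in for a tree transformation that replaces a leaf without going through the setter, which is what the `assert_int` example above implies.

```python
from typing import Any, Callable, Sequence


class Field:
    """Hypothetical descriptor running callbacks on attribute set/get."""

    def __init__(
        self,
        on_setattr: Sequence[Callable[[Any], Any]] = (),
        on_getattr: Sequence[Callable[[Any], Any]] = (),
    ) -> None:
        self.on_setattr = list(on_setattr)
        self.on_getattr = list(on_getattr)

    def __set_name__(self, owner: type, name: str) -> None:
        self.name = name

    def __get__(self, instance: Any, owner: Any = None) -> Any:
        if instance is None:
            return self  # accessed on the class itself
        value = vars(instance)[self.name]
        for func in self.on_getattr:  # applied in order on every read
            value = func(value)
        return value

    def __set__(self, instance: Any, value: Any) -> None:
        for func in self.on_setattr:  # applied in order before storing
            value = func(value)
        vars(instance)[self.name] = value


def assert_int(x: Any) -> Any:
    assert isinstance(x, int), "must be an int"
    return x


class Tree:
    a = Field(on_setattr=[assert_int], on_getattr=[assert_int])

    def __init__(self, a: Any) -> None:
        self.a = a  # validated by on_setattr

    def __call__(self, x: float) -> float:
        return self.a + x  # validated again by on_getattr


tree = Tree(a=1)
print(tree(1.0))  # 2.0
vars(tree)["a"] = 1.0  # swap the stored value without calling __set__
# tree(1.0) now raises AssertionError: must be an int
```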
- Remove `.at` as an alias for `__getitem__` when specifying a path entry for `where` in `AtIndexer`. This leads to a less verbose style.
Example:
>>> import pytreeclass as tc
>>> tree = {"level1_0": {"level2_0": 100, "level2_1": 200}, "level1_1": 300}
>>> indexer = tc.AtIndexer(tree)
>>> # Before:
>>> # style 1 (with at):
>>> indexer.at["level1_0"].at["level2_0", "level2_1"].get()
{'level1_0': {'level2_0': 100, 'level2_1': 200}, 'level1_1': None}
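>>> # style 2 (no at):
>>> indexer["level1_0"]["level2_0", "level2_1"].get()
{'level1_0': {'level2_0': 100, 'level2_1': 200}, 'level1_1': None}
>>> # After:
>>> # only style 2 is valid
>>> indexer["level1_0"]["level2_0", "level2_1"].get()
{'level1_0': {'level2_0': 100, 'level2_1': 200}, 'level1_1': None}

Migrating existing code:
# before
tree = indexer.at["level1_0"].at["level2_0", "level2_1"].get()
# after
tree = indexer["level1_0"]["level2_0", "level2_1"].get()

For `TreeClass`, `at` is specified once for each change: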
import jax
import jax.numpy as jnp
import pytreeclass as tc

@tc.autoinit
class Tree(tc.TreeClass):
    a: float = 1.0
    b: tuple[float, float] = (2.0, 3.0)
    c: jax.Array = jnp.array([4.0, 5.0, 6.0])

    def __call__(self, x):
        return self.a + self.b[0] + self.c + x

tree = Tree()
mask = jax.tree_util.tree_map(lambda x: x > 5, tree)
tree = (
    tree
    .at["a"].set(100.0)
    .at["b"][0].set(10.0)  # previously: .at["b"].at[0].set(10.0)
    .at[mask].set(100.0)
)

- Using `tree_{repr,str}` with an object containing cyclic references will raise `RecursionError` instead of displaying `cyclicref`.
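Dropping the `.at` alias from `AtIndexer` paths loses nothing, because each `[...]` only needs to append one level to a path. The sketch below models that for nested dicts only: `PathIndexer` and `none_like` are hypothetical (this is not `AtIndexer`), and `get` returns the same shape as the `.get()` output in the example above, with unselected leaves replaced by `None`.

```python
from typing import Any, Dict, Tuple


def none_like(node: Any) -> Any:
    """Keep the nested-dict structure but replace every leaf with None."""
    if isinstance(node, dict):
        return {k: none_like(v) for k, v in node.items()}
    return None


class PathIndexer:
    """Hypothetical sketch: each [...] appends one path level; .get() selects it."""

    def __init__(self, tree: Dict[str, Any], where: Tuple[Tuple[Any, ...], ...] = ()) -> None:
        self.tree = tree
        self.where = where

    def __getitem__(self, keys: Any) -> "PathIndexer":
        # a tuple such as ["k1", "k2"] selects several keys at the same level
        keys = keys if isinstance(keys, tuple) else (keys,)
        return PathIndexer(self.tree, self.where + (keys,))

    def get(self) -> Any:
        def select(node: Any, depth: int) -> Any:
            if depth == len(self.where):
                return node  # whole path matched: keep this subtree
            if not isinstance(node, dict):
                return None  # path goes deeper than this leaf
            return {
                k: (select(v, depth + 1) if k in self.where[depth] else none_like(v))
                for k, v in node.items()
            }

        return select(self.tree, 0)


tree = {"level1_0": {"level2_0": 100, "level2_1": 200}, "level1_1": 300}
print(PathIndexer(tree)["level1_0"]["level2_0", "level2_1"].get())
# {'level1_0': {'level2_0': 100, 'level2_1': 200}, 'level1_1': None}
```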
- Allow nested mutations using `.at[method](*args, **kwargs)`. After this change, inner methods can mutate copied new instances at any level, not just the top level. One motivation is to experiment with lazy initialization schemes, where inner layers need to mutate their inner state. See the example below for `flax`-like lazy initialization as described here:

import jax
import jax.numpy as jnp
import jax.random as jr
import pytreeclass as tc
from typing import Callable, TypeVar

T = TypeVar("T")

@tc.autoinit
class LazyLinear(tc.TreeClass):
    outdim: int
    weight_init: Callable[..., T] = jax.nn.initializers.glorot_normal()
    bias_init: Callable[..., T] = jax.nn.initializers.zeros

    def param(self, name: str, init_func: Callable[..., T], *args) -> T:
        # create the parameter on first use, reuse it afterwards
        if name not in vars(self):
            setattr(self, name, init_func(*args))
        return vars(self)[name]

    def __call__(self, x: jax.Array, *, key: jax.Array = jr.PRNGKey(0)):
        w = self.param("weight", self.weight_init, key, (x.shape[-1], self.outdim))
        y = x @ w
        if self.bias_init is not None:
            b = self.param("bias", self.bias_init, key, (self.outdim,))
            return y + b
        return y

@tc.autoinit
class StackedLinear(tc.TreeClass):
    l1: LazyLinear = LazyLinear(outdim=10)
    l2: LazyLinear = LazyLinear(outdim=1)

    def __call__(self, x: jax.Array):
        return self.l2(jax.nn.relu(self.l1(x)))

lazy_layer = StackedLinear()
print(repr(lazy_layer))
# StackedLinear(
#   l1=LazyLinear(
#     outdim=10,
#     weight_init=init(key, shape, dtype),
#     bias_init=zeros(key, shape, dtype)
#   ),
#   l2=LazyLinear(
#     outdim=1,
#     weight_init=init(key, shape, dtype),
#     bias_init=zeros(key, shape, dtype)
#   )
# )

# calling through `.at` mutates a copy and returns (output, new_instance)
_, materialized_layer = lazy_layer.at["__call__"](jnp.ones((1, 5)))
materialized_layer
# StackedLinear(
#   l1=LazyLinear(
#     outdim=10,
#     weight_init=init(key, shape, dtype),
#     bias_init=zeros(key, shape, dtype),
#     weight=f32[5,10](μ=-0.04, σ=0.32, ∈[-0.74,0.63]),
#     bias=f32[10](μ=0.00, σ=0.00, ∈[0.00,0.00])
#   ),
#   l2=LazyLinear(
#     outdim=1,
#     weight_init=init(key, shape, dtype),
#     bias_init=zeros(key, shape, dtype),
#     weight=f32[10,1](μ=-0.07, σ=0.23, ∈[-0.34,0.34]),
#     bias=f32[1](μ=0.00, σ=0.00, ∈[0.00,0.00])
#   )
# )

materialized_layer(jnp.ones((1, 5)))
# Array([[0.16712935]], dtype=float32)
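The contract of `.at[method](*args, **kwargs)` can be sketched in a few lines: copy the instance, call the method on the copy (so nested attribute writes land on the copy at any depth), and return `(output, copy)`. `MethodAt`, `LazyScale`, and `Stack` are hypothetical plain-Python classes; unlike `TreeClass`, they do not forbid mutation outside `.at`, so this only illustrates the copy-then-call behavior.

```python
import copy
from typing import Any, Callable, Tuple


class MethodAt:
    """Hypothetical: obj_at[name](*args) -> (output, mutated deep copy of obj)."""

    def __init__(self, obj: Any) -> None:
        self.obj = obj

    def __getitem__(self, name: str) -> Callable[..., Tuple[Any, Any]]:
        def call(*args: Any, **kwargs: Any) -> Tuple[Any, Any]:
            new = copy.deepcopy(self.obj)  # the original is never touched
            output = getattr(new, name)(*args, **kwargs)
            return output, new

        return call


class LazyScale:
    """Creates its `scale` parameter on first call, like `LazyLinear.param`."""

    def param(self, name: str, init: Callable[[], Any]) -> Any:
        if name not in vars(self):
            setattr(self, name, init())
        return vars(self)[name]

    def __call__(self, x: float) -> float:
        return x * self.param("scale", lambda: 2.0)


class Stack:
    def __init__(self) -> None:
        self.l1 = LazyScale()
        self.l2 = LazyScale()

    def __call__(self, x: float) -> float:
        return self.l2(self.l1(x))


lazy = Stack()
out, materialized = MethodAt(lazy)["__call__"](3.0)
print(out)  # 12.0
print(vars(materialized.l1))  # {'scale': 2.0}
print(vars(lazy.l1))  # {}
```

Deep-copying is what lets the inner layers (`l1`, `l2`) mutate freely without leaking state back into `lazy`.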
- Fix `__init_subclass__` not accepting arguments. This bug was introduced in v0.5.
Alternatives

Use one of the following:
- Preferably, decorate with `pytreeclass.autoinit` with `pytreeclass.field` as the field specifier, as `pytreeclass.field` has more features (e.g. callbacks, multiple argument kind selection) and its init generation is cached, compared to `dataclasses`.
- Decorate with `dataclasses.dataclass` with `dataclasses.field` as the field specifier. However:
  - Keep `frozen=False` (the default), because `__setattr__` and `__delattr__` are handled by `TreeClass`.
  - Optionally set `repr=False`, to let `TreeClass` handle it.
  - Optionally set `eq=False` (`unsafe_hash` already defaults to `False`), as equality and hashing are handled by `TreeClass`.
Previously, subclassing `TreeClass` was enough:

import jax.tree_util as jtu
import pytreeclass as tc

class Tree(tc.TreeClass):
    a: int = 1

jtu.tree_leaves(Tree())
# [1]
Equivalent behavior is now obtained by decorating with either `tc.autoinit` or `dataclasses.dataclass`:

import jax.tree_util as jtu
import pytreeclass as tc

@tc.autoinit
class Tree(tc.TreeClass):
    a: int = 1

jtu.tree_leaves(Tree())
# [1]
This change aims to fix the ambiguity of using the dataclass mental model in the following situations:
- Subclassing. Previously, using `TreeClass` as a base class was equivalent to decorating the class with `dataclasses.dataclass`; however, this is a bit challenging to understand, as demonstrated in the next example:

import pytreeclass as tc

class A(tc.TreeClass):
    def __init__(self, a: int):
        self.a = a

class B(A):
    ...

Instantiating `B(a=...)` raises an error, because using `TreeClass` is equivalent to decorating all classes with `@dataclass`, which synthesizes the `__init__` method based on the fields. Since `B` has no fields (i.e. type-hinted values), the synthesized `B.__init__` accepts no arguments besides `self`. The previous code is equivalent to this code:

import dataclasses as dc

@dc.dataclass
class A:
    def __init__(self, a: int):
        self.a = a

@dc.dataclass
class B(A):
    ...
- `dataclass_transform` does not play nicely with a user-created `__init__`; see 1, 2.

Instead, decorate the class with `pytreeclass.leafwise`.
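The subclassing pitfall described above (`B(A)` rejecting `a=...`) can be reproduced with the standard library alone. This sketch uses only `dataclasses`, no pytreeclass: `@dataclass` keeps an `__init__` written in the class body, but a subclass without its own `__init__` gets a freshly synthesized one built from its (empty) field list.

```python
import dataclasses as dc
import inspect


@dc.dataclass
class A:
    # an __init__ defined in the class body is kept by @dataclass
    def __init__(self, a: int) -> None:
        self.a = a


@dc.dataclass
class B(A):
    ...  # no fields and no own __init__: @dataclass synthesizes B.__init__(self)


print(A(a=1).a)  # 1
print(list(inspect.signature(B.__init__).parameters))  # ['self']
try:
    B(a=1)
except TypeError as err:
    print("TypeError:", err)  # B(a=...) is rejected
```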
- A user-provided `re.Pattern` is used to match keys with a regex pattern instead of using `RegexKey`.

Example:

import pytreeclass as tc
import re

tree = {"l1": 1, "l2": 2, "b": 3}
tree = tc.AtIndexer(tree)
tree.at[re.compile("l.*")].get()
# {'b': None, 'l1': 1, 'l2': 2}

- `RegexKey` is deprecated; use `re`-compiled patterns instead.
- `tree_indent` is deprecated; use `tree_diagram(tree).replace(...)` to replace the edge characters with spaces.
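Matching keys against a compiled pattern can be sketched with the standard library. `select_by_pattern` is hypothetical and works on a flat dict only; it uses `re.Pattern.fullmatch`, and whether pytreeclass anchors the match the same way is an assumption. The key order differs from the output above because JAX flattens dicts in sorted key order, while this dict comprehension keeps insertion order.

```python
import re
from typing import Any, Dict


def select_by_pattern(tree: Dict[str, Any], pattern: "re.Pattern[str]") -> Dict[str, Any]:
    """Hypothetical: keep values whose key fully matches `pattern`, None elsewhere."""
    return {k: (v if pattern.fullmatch(k) else None) for k, v in tree.items()}


tree = {"l1": 1, "l2": 2, "b": 3}
print(select_by_pattern(tree, re.compile("l.*")))  # {'l1': 1, 'l2': 2, 'b': None}
```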
- Add `tree_mask` and `tree_unmask` to freeze/unfreeze tree leaves based on a callable or boolean pytree mask. By default, non-inexact leaves are masked with a frozen wrapper.

Example: pass non-`jax` types through `jax` transformations without error.

# pass non-differentiable values to `jax.grad`
import pytreeclass as tc
import jax

@jax.grad
def square(tree):
    tree = tc.tree_unmask(tree)
    return tree[0]**2

tree = (1., 2)  # contains a non-differentiable node
square(tc.tree_mask(tree))
# (Array(2., dtype=float32, weak_type=True), #2)
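The mask/unmask idea can be sketched without JAX: masking wraps selected leaves in a frozen wrapper, a leaf-level transformation skips wrapped values, and unmasking strips the wrappers again. Everything below (`Frozen`, `tree_mask`, `tree_unmask`, `map_visible`) is a hypothetical plain-Python stand-in that handles only lists, tuples, and dicts; in pytreeclass the wrapper instead hides the value from `jax` transformations.

```python
from dataclasses import dataclass
from typing import Any, Callable


@dataclass(frozen=True)
class Frozen:
    """Hypothetical wrapper marking a leaf as hidden from transformations."""

    value: Any


def _tree_map(func: Callable[[Any], Any], tree: Any) -> Any:
    # recurse into lists, tuples and dicts; everything else is a leaf
    if isinstance(tree, (list, tuple)):
        return type(tree)(_tree_map(func, x) for x in tree)
    if isinstance(tree, dict):
        return {k: _tree_map(func, v) for k, v in tree.items()}
    return func(tree)


def _is_non_inexact(x: Any) -> bool:
    # default rule for plain Python scalars: only float/complex stay visible
    return not isinstance(x, (float, complex))


def tree_mask(tree: Any, cond: Callable[[Any], bool] = _is_non_inexact) -> Any:
    return _tree_map(lambda x: Frozen(x) if cond(x) else x, tree)


def tree_unmask(tree: Any) -> Any:
    return _tree_map(lambda x: x.value if isinstance(x, Frozen) else x, tree)


def map_visible(func: Callable[[Any], Any], tree: Any) -> Any:
    """A stand-in transformation that never touches frozen leaves."""
    return _tree_map(lambda x: x if isinstance(x, Frozen) else func(x), tree)


masked = tree_mask((1.0, 2, "label"))
print(masked)  # (1.0, Frozen(value=2), Frozen(value='label'))
print(tree_unmask(map_visible(lambda x: x * 2, masked)))  # (2.0, 2, 'label')
```

Without the mask, `x * 2` would have turned `"label"` into `"labellabel"`; the wrapper is what keeps the non-inexact leaves out of the transformation.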
- Support extending match keys by adding the abstract base class `BaseKey`; check its docstring for an example.
- Support multi-indexing with any acceptable form, e.g. a boolean pytree, key, int, or `BaseKey` instance.

Example:

import pytreeclass as tc

tree = {"l1": 1, "l2": 2, "b": 3}
tree = tc.AtIndexer(tree)
tree.at["l1", "l2"].get()
# {'b': None, 'l1': 1, 'l2': 2}
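How a custom matcher and multi-indexing can fit together is sketched below on a flat dict: an index is either a plain key or an object with a `match(key, leaf)` method, and a multi-index keeps a value if any of its parts matches. `KeyMatcher`, `ValueAbove`, and `select` are hypothetical; the real `BaseKey` abstract method is documented in its docstring and may differ from the `match` signature assumed here.

```python
import abc
from typing import Any, Dict


class KeyMatcher(abc.ABC):
    """Hypothetical stand-in for an extensible match-key base class."""

    @abc.abstractmethod
    def match(self, key: str, leaf: Any) -> bool:
        """Return True if the entry (key, leaf) should be selected."""


class ValueAbove(KeyMatcher):
    """Custom matcher: selects entries whose value exceeds a threshold."""

    def __init__(self, threshold: float) -> None:
        self.threshold = threshold

    def match(self, key: str, leaf: Any) -> bool:
        return leaf > self.threshold


def select(tree: Dict[str, Any], *where: Any) -> Dict[str, Any]:
    """Keep a value if ANY index in `where` matches it; None elsewhere."""

    def matches(key: str, leaf: Any, index: Any) -> bool:
        if isinstance(index, KeyMatcher):
            return index.match(key, leaf)
        return key == index  # plain keys compare by equality

    return {k: (v if any(matches(k, v, w) for w in where) else None) for k, v in tree.items()}


tree = {"l1": 1, "l2": 2, "b": 3}
print(select(tree, "l1", "l2"))  # {'l1': 1, 'l2': 2, 'b': None}
print(select(tree, "l1", ValueAbove(2)))  # {'l1': 1, 'l2': None, 'b': 3}
```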
- Add `scan` to `AtIndexer` to carry a state while applying a function.

Example:

import pytreeclass as tc

def scan_func(leaf, state):
    # increase the state by 1 for each function call
    return leaf**2, state + 1

tree = {"l1": 1, "l2": 2, "b": 3}
tree = tc.AtIndexer(tree)
tree, state = tree.at["l1", "l2"].scan(scan_func, 0)
state
# 2
tree
# {'b': 3, 'l1': 1, 'l2': 4}
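The behavior of `scan` in the example can be sketched in plain Python: the state is threaded through each call on a selected leaf, and unselected values are kept unchanged (not set to `None` as with `get`). `scan_selected` is hypothetical and handles a flat dict only; since state flows in iteration order, note that this sketch visits keys in insertion order, whereas JAX flattens dicts in sorted key order.

```python
from typing import Any, Callable, Dict, Tuple


def scan_selected(
    func: Callable[[Any, Any], Tuple[Any, Any]],
    tree: Dict[str, Any],
    keys: Tuple[str, ...],
    state: Any,
) -> Tuple[Dict[str, Any], Any]:
    """Hypothetical: apply func(leaf, state) -> (new_leaf, new_state) to selected keys."""
    out = dict(tree)  # shallow copy; unselected values pass through unchanged
    for key, leaf in tree.items():  # state flows in iteration order
        if key in keys:
            out[key], state = func(leaf, state)
    return out, state


def scan_func(leaf: int, state: int) -> Tuple[int, int]:
    # square the leaf and count how many times the function ran
    return leaf**2, state + 1


new_tree, state = scan_selected(scan_func, {"l1": 1, "l2": 2, "b": 3}, ("l1", "l2"), 0)
print(state)  # 2
print(new_tree)  # {'l1': 1, 'l2': 4, 'b': 3}
```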
- `tree_summary` improvements:
  - Add a size column to `tree_summary`.
  - Add `def_count` to dispatch the count rule for a type.
  - Add `def_size` to dispatch the size rule for a type.
  - Add `def_type` to dispatch the type display.

Example:

import pytreeclass as tc
import jax.numpy as jnp

x = jnp.ones((5, 5))
print(tc.tree_summary([1, 2, 3, x]))
# ┌────┬────────┬─────┬───────┐
# │Name│Type    │Count│Size   │
# ├────┼────────┼─────┼───────┤
# │[0] │int     │1    │       │
# ├────┼────────┼─────┼───────┤
# │[1] │int     │1    │       │
# ├────┼────────┼─────┼───────┤
# │[2] │int     │1    │       │
# ├────┼────────┼─────┼───────┤
# │[3] │f32[5,5]│25   │100.00B│
# ├────┼────────┼─────┼───────┤
# │Σ   │list    │28   │100.00B│
# └────┴────────┴─────┴───────┘

# make list display its number of elements in the type row
@tc.tree_summary.def_type(list)
def _(_: list) -> str:
    return f"List[{len(_)}]"

print(tc.tree_summary([1, 2, 3, x]))
# ┌────┬────────┬─────┬───────┐
# │Name│Type    │Count│Size   │
# ├────┼────────┼─────┼───────┤
# │[0] │int     │1    │       │
# ├────┼────────┼─────┼───────┤
# │[1] │int     │1    │       │
# ├────┼────────┼─────┼───────┤
# │[2] │int     │1    │       │
# ├────┼────────┼─────┼───────┤
# │[3] │f32[5,5]│25   │100.00B│
# ├────┼────────┼─────┼───────┤
# │Σ   │List[4] │28   │100.00B│
# └────┴────────┴─────┴───────┘
- Export pytrees to the dot language using `tree_graph`.

# define a custom style for a node by dispatching on the value;
# the defined function should return a dict of attributes
# that will be passed to graphviz.
import graphviz
import pytreeclass as tc

tree = [1, 2, dict(a=3)]

@tc.tree_graph.def_nodestyle(list)
def _(_) -> dict[str, str]:
    return dict(shape="circle", style="filled", fillcolor="lightblue")

dot_graph = graphviz.Source(tc.tree_graph(tree))
dot_graph
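For a sense of what exporting to the dot language involves, here is a minimal sketch that walks nested lists, tuples, and dicts and emits one DOT node per value plus one edge per parent-child link. `to_dot` is hypothetical and applies no per-type styles (the role `def_nodestyle` plays above); its output is a plain DOT string that `graphviz.Source` can render, as in the example above.

```python
from typing import Any, List, Tuple


def to_dot(tree: Any, name: str = "tree") -> str:
    """Hypothetical: render nested lists/tuples/dicts as a Graphviz digraph."""
    lines: List[str] = ["digraph G {"]
    counter = 0

    def visit(node: Any, label: str) -> int:
        nonlocal counter
        node_id = counter
        counter += 1
        children: List[Tuple[str, Any]] = []
        if isinstance(node, dict):
            children = [(str(k), v) for k, v in node.items()]
            text = f"{label}: {type(node).__name__}"
        elif isinstance(node, (list, tuple)):
            children = [(f"[{i}]", v) for i, v in enumerate(node)]
            text = f"{label}: {type(node).__name__}"
        else:
            text = f"{label}={node!r}"
        text = text.replace('"', '\\"')  # escape quotes inside DOT labels
        lines.append(f'    {node_id} [label="{text}"];')
        for child_label, child in children:
            child_id = visit(child, child_label)
            lines.append(f"    {node_id} -> {child_id};")
        return node_id

    visit(tree, name)
    lines.append("}")
    return "\n".join(lines)


print(to_dot([1, 2, dict(a=3)]))
```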
- Add variable positional arguments and variable keyword arguments to the `kind` option of `tc.field`.

import pytreeclass as tc

class Tree(tc.TreeClass):
    a: int = tc.field(kind="VAR_POS")
    b: int = tc.field(kind="POS_ONLY")
    c: int = tc.field(kind="VAR_KW")
    d: int
    e: int = tc.field(kind="KW_ONLY")

Tree.__init__
# <function __main__.Tree.__init__(self, b: int, /, d: int, *a: int, e: int, **c: int) -> None>
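The reordering in the printed signature follows Python's own parameter-kind rules, which can be reproduced with `inspect`: map each kind tag to an `inspect.Parameter` kind and stably sort by kind, which keeps declaration order within each kind. This is a sketch, not pytreeclass's init generator; the `"POS_OR_KW"` tag for the untagged field `d` is a hypothetical name used only here.

```python
import inspect
from typing import Any, List, Tuple

P = inspect.Parameter

# kind tags from the example above; "POS_OR_KW" is a hypothetical tag for `d`
KINDS = {
    "POS_ONLY": P.POSITIONAL_ONLY,
    "POS_OR_KW": P.POSITIONAL_OR_KEYWORD,
    "VAR_POS": P.VAR_POSITIONAL,
    "KW_ONLY": P.KEYWORD_ONLY,
    "VAR_KW": P.VAR_KEYWORD,
}


def build_init_signature(fields: List[Tuple[str, Any, str]]) -> inspect.Signature:
    """Build an __init__ signature from (name, annotation, kind-tag) triples."""
    params = [P("self", P.POSITIONAL_ONLY)]
    params += [P(name, KINDS[kind], annotation=ann) for name, ann, kind in fields]
    # inspect.Signature rejects out-of-order kinds, so sort by kind;
    # sorted() is stable, so declaration order survives within each kind
    return inspect.Signature(sorted(params, key=lambda p: p.kind))


fields = [
    ("a", int, "VAR_POS"),
    ("b", int, "POS_ONLY"),
    ("c", int, "VAR_KW"),
    ("d", int, "POS_OR_KW"),
    ("e", int, "KW_ONLY"),
]
print(build_init_signature(fields))
# (self, b: int, /, d: int, *a: int, e: int, **c: int)
```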
This release introduces extensive `functools.singledispatch` usage to enable greater customization:
- `{freeze,unfreeze,is_nondiff}.def_type` to define how to `freeze` a type, how to unfreeze it, and whether it is considered nondiff or not. These rules are used by these functions and by `tree_mask`/`tree_unmask`.
- `tree_graph.def_nodestyle` and `tree_summary.def_{count,type,size}` for pretty-printing customization.
- `BaseKey.def_alias` to define type alias usage inside `AtIndexer`/`.at`.
- Internally, most of the pretty printing uses dispatching to define repr/str rules for each instance type.
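The rule-registry pattern behind these `def_*` hooks can be illustrated with `functools.singledispatch` directly: a default rule covers unknown types, and users register more specific rules that win by MRO. `nondiff_rule` and `Meters` below are hypothetical plain-Python stand-ins; the exact signature of pytreeclass's `def_type` is not shown in these notes.

```python
from functools import singledispatch
from typing import Any


@singledispatch
def nondiff_rule(value: Any) -> bool:
    """Hypothetical default: unknown types are treated as non-differentiable."""
    return True


@nondiff_rule.register(float)
@nondiff_rule.register(complex)
def _(value: Any) -> bool:
    return False  # inexact scalars can carry gradients


class Meters(float):
    """A user-defined float subclass that should stay non-differentiable."""


@nondiff_rule.register(Meters)
def _(value: Any) -> bool:
    return True  # more specific than float, so this rule wins


print([nondiff_rule(v) for v in (1.5, 2, 1j, "x", Meters(3.0))])
# [False, True, False, True, True]
```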
