netket.utils.struct.field



netket.utils.struct.field(pytree_node=True, pytree_ignore=False, serialize=None, serialize_name=None, cache=False, sharded=False, **kwargs)

Mark a field of a dataclass or PyTree and configure how it is handled during pytree flattening, serialization, caching, and sharding.

Parameters:
  • pytree_node (bool) – If True (the default), this field is a leaf node in the pytree representation of this dataclass. If False, the field is not a leaf and its value must be hashable.

  • pytree_ignore (bool) – If True, this field is ignored by the pytree metadata and excluded from pytree flattening and unflattening, so it does not appear in the flattened representation produced by jax.tree_util.tree_flatten() or similar pytree operations. This is useful for caches, temporary data, or other fields that should not be carried through pytree transformations. When True, pytree_node must be False.

  • serialize (bool | None) – If True, the field is included in serialization. In general you should not need to specify this; it defaults to the value of pytree_node.

  • serialize_name (str | None) – If specified, the name under which this attribute is serialized. This lets you rename the runtime attribute while keeping a different name in the serialization format.

  • cache (bool) – If True, this field is a cache and is reset every time fields are modified.

  • sharded (bool | ShardedFieldSpec) – A boolean or specification object indicating whether this entry is sharded. Defaults to False. If True, JAX-compatible sharding along axis 0 is assumed.
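The interplay of pytree_node, pytree_ignore, serialize and serialize_name can be illustrated with a simplified, standard-library-only sketch. It is not NetKet's implementation: sketch_field, sketch_flatten, sketch_serialize and the Example class are hypothetical names, and the flatten step only splits top-level fields into leaves and hashable static data rather than building a real JAX PyTreeDef.

```python
import dataclasses
from typing import Any


def sketch_field(pytree_node=True, pytree_ignore=False, serialize=None,
                 serialize_name=None, **kwargs):
    """Hypothetical field marker: records the flags in dataclass metadata."""
    if pytree_ignore and pytree_node:
        raise ValueError("pytree_ignore=True requires pytree_node=False")
    if serialize is None:
        serialize = pytree_node  # documented default: follow pytree_node
    metadata = {
        "pytree_node": pytree_node,
        "pytree_ignore": pytree_ignore,
        "serialize": serialize,
        "serialize_name": serialize_name,
    }
    return dataclasses.field(metadata=metadata, **kwargs)


def sketch_flatten(obj):
    """Split top-level fields into (leaves, static data); skip ignored fields."""
    leaves, static = [], []
    for f in dataclasses.fields(obj):
        value = getattr(obj, f.name)
        if f.metadata.get("pytree_ignore", False):
            continue  # excluded from flattening altogether
        if f.metadata.get("pytree_node", True):
            leaves.append(value)
        else:
            hash(value)  # non-leaf data must be hashable (TypeError otherwise)
            static.append((f.name, value))
    return leaves, tuple(static)


def sketch_serialize(obj):
    """Collect serialized fields, honoring serialize and serialize_name."""
    out = {}
    for f in dataclasses.fields(obj):
        if f.metadata.get("serialize", f.metadata.get("pytree_node", True)):
            key = f.metadata.get("serialize_name") or f.name
            out[key] = getattr(obj, f.name)
    return out


@dataclasses.dataclass
class Example:
    weights: Any  # plain field: treated as a leaf and serialized
    n_steps: int = sketch_field(pytree_node=False, default=10)
    scratch: Any = sketch_field(pytree_node=False, pytree_ignore=True,
                                default=None)
    renamed: float = sketch_field(serialize_name="old_name", default=0.5)
```

With Example(weights=[1.0, 2.0]), sketch_flatten returns the leaves [[1.0, 2.0], 0.5] with ("n_steps", 10) as static data and skips scratch entirely, while sketch_serialize yields {"weights": [1.0, 2.0], "old_name": 0.5}, because n_steps and scratch inherit serialize=False from pytree_node=False. Passing an unhashable value such as a list for n_steps makes sketch_flatten raise TypeError.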
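The documented behavior of cache=True, that a cache field is reset whenever fields are modified, can be sketched in the same spirit. Here the "cache" metadata key, sketch_replace and Vector are hypothetical; the sketch assumes modification happens through a functional, dataclasses.replace-style update, and it simply resets every cache field the caller did not set explicitly.

```python
import dataclasses
import math
from typing import Optional, Tuple


def sketch_replace(obj, **changes):
    """Apply `changes` and reset every cache field the caller did not set."""
    resets = {
        f.name: None
        for f in dataclasses.fields(obj)
        if f.metadata.get("cache", False) and f.name not in changes
    }
    return dataclasses.replace(obj, **changes, **resets)


@dataclasses.dataclass
class Vector:
    values: Tuple[float, ...]
    # Hypothetical cache field, flagged through metadata for this sketch only.
    norm_cache: Optional[float] = dataclasses.field(
        default=None, metadata={"cache": True}
    )

    def norm(self) -> float:
        # Compute lazily and memoize the result in the cache field.
        if self.norm_cache is None:
            self.norm_cache = math.sqrt(sum(v * v for v in self.values))
        return self.norm_cache
```

Calling norm() on Vector((3.0, 4.0)) fills norm_cache with 5.0; sketch_replace(v, values=(6.0, 8.0)) returns a copy whose norm_cache is None again, so its norm() recomputes 10.0 instead of returning the stale value, and the original object keeps its own cache.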