netket.jax.tree_axpy

netket.jax.tree_axpy(a, x, y)

Compute a * x + y leaf-wise over pytrees.

Parameters:
  • a (Any) – scalar or pytree

  • x (Any) – pytree with the same treedef as y

  • y (Any) – pytree with the same treedef as x

Return type:

Any

Returns:

A pytree whose leaves are the sums of the corresponding leaves of x and y, where each leaf of x is first scaled by a.
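The leaf-wise semantics above can be illustrated with a minimal sketch built on jax.tree_util.tree_map. The helper tree_axpy_sketch is hypothetical and is not NetKet's implementation. In particular, the rule it uses to decide whether a is a scalar or a pytree (comparing its tree structure with that of x) is an assumption made for illustration.

```python
import jax
import jax.numpy as jnp


def tree_axpy_sketch(a, x, y):
    """Hypothetical sketch of a * x + y over pytrees (not NetKet's code).

    Assumes x and y share the same treedef. If `a` has the same tree
    structure as x, each leaf of x is scaled by the matching leaf of a;
    otherwise `a` is treated as a scalar applied to every leaf of x.
    """
    if jax.tree_util.tree_structure(a) == jax.tree_util.tree_structure(x):
        # `a` mirrors x: combine the three pytrees leaf by leaf.
        return jax.tree_util.tree_map(lambda a_, x_, y_: a_ * x_ + y_, a, x, y)
    # Assumed scalar case: broadcast `a` over every leaf of x.
    return jax.tree_util.tree_map(lambda x_, y_: a * x_ + y_, x, y)


x = {"w": jnp.array([1.0, 2.0]), "b": jnp.array(3.0)}
y = {"w": jnp.array([10.0, 20.0]), "b": jnp.array(1.0)}

# Scalar a: every leaf of x is scaled by 2.0 before adding y.
z = tree_axpy_sketch(2.0, x, y)
# z["w"] -> [12., 24.], z["b"] -> 7.

# Pytree a: each leaf of x gets its own scale factor.
a = {"w": jnp.array(0.5), "b": jnp.array(-1.0)}
z2 = tree_axpy_sketch(a, x, y)
# z2["w"] -> [10.5, 21.], z2["b"] -> -2.
```

A typical use of this operation is a gradient-descent style update, params = (-learning_rate) * grads + params, which returns a new pytree with the same structure as params.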