netket.jax.expect

netket.jax.expect(log_pdf, expected_fun, pars, σ, *expected_fun_args, n_chains=None, chunk_size=None, in_axes=None)

Computes the expectation value of a function over a probability distribution specified by its log-pdf,

\[\langle f \rangle = \mathbb{E}_{\sigma \sim p}[f(\sigma)] = \sum_{\sigma} p(\sigma) f(\sigma)\]

where the expectation value is approximated by the sample average, with samples \(\sigma\) assumed to be drawn from the probability distribution \(p(\sigma)\):

\[\langle f \rangle \approx \frac{1}{N} \sum_{i=1}^{N} f(\sigma_i)\]
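
The forward value is therefore just a sample mean, which could be written directly in JAX. The following is a minimal sketch with a hypothetical integrand f and placeholder samples:

>>> import jax.numpy as jnp
>>> f = lambda σ: jnp.sum(σ, axis=-1)  # hypothetical integrand f(σ)
>>> σ = jnp.ones((100, 20))            # placeholder batch of N = 100 samples
>>> f_mean = jnp.mean(f(σ))            # sample-average estimate of ⟨f⟩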

This function ensures that the backward pass is computed correctly: the first equation above is differentiated exactly, and the resulting expectation values are then approximated again by the sample average. The resulting backward gradient is

\[\nabla \langle f \rangle = \mathbb{E}_{\sigma \sim p}[(\nabla \log p(\sigma)) f(\sigma) + \nabla f(\sigma)]\]

where, again, the expectation values are computed using the sample average.
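
In JAX terms, this is the score-function (REINFORCE) estimator. A minimal hand-written sketch of the formula above, with a hypothetical scalar parameter θ, a hypothetical log-pdf log_p and integrand f (variance-reduction baselines are omitted here):

>>> import jax
>>> import jax.numpy as jnp
>>> log_p = lambda θ, σ: -θ * jnp.sum(σ**2, axis=-1)  # hypothetical log-pdf
>>> f = lambda θ, σ: θ * jnp.sum(σ, axis=-1)          # hypothetical integrand
>>> σ = jnp.ones((100, 20))
>>> θ = 0.5
>>> # ∇⟨f⟩ ≈ sample mean of (∇ log p)(σ) f(σ) + (∇ f)(σ)
>>> dlogp = jax.vmap(jax.grad(log_p), in_axes=(None, 0))(θ, σ)
>>> df = jax.vmap(jax.grad(f), in_axes=(None, 0))(θ, σ)
>>> grad_estimate = jnp.mean(dlogp * f(θ, σ) + df)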

Example

Compute the energy gradient using nk.jax.expect.

>>> import netket as nk
>>> import jax
>>> import jax.numpy as jnp
>>>
>>> hi = nk.hilbert.Spin(s=0.5, N=20)
>>> graph = nk.graph.Chain(length=20)
>>> H = nk.operator.IsingJax(hi, graph, h=1.0)
>>> vstate = nk.vqs.MCState(sampler=nk.sampler.MetropolisLocal(hi, n_chains_per_rank=16), model=nk.models.RBM(alpha=1, param_dtype=complex), n_samples=100000)
>>>
>>> afun = vstate._apply_fun
>>> pars = vstate.parameters
>>> model_state = vstate.model_state
>>>
>>> σ = vstate.samples
>>> σ = σ.reshape(-1, σ.shape[-1])
>>>
>>> # The function that we want to differentiate wrt pars and σ
>>> # Note that we do not want to compute the gradient wrt model_state,
>>> # so we capture it inside this function.
>>> def expect(pars, σ):
...
...     # The log probability distribution we have generated samples σ from.
...     def log_pdf(pars, σ):
...         W = {"params": pars, **model_state}
...         return 2 * afun(W, σ).real
...
...     def expected_fun(pars, σ):
...         W = {"params": pars, **model_state}
...         # Get connected samples
...         σp, mels = H.get_conn_padded(σ)
...         logpsi_σ = afun(W, σ)
...         logpsi_σp = afun(W, σp)
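...         # log(⟨σ|H|ψ⟩): sum the matrix elements against ψ(σ') in log-space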
...         logHpsi_σ = jax.scipy.special.logsumexp(logpsi_σp, b=mels, axis=1)
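...         # Local estimator E_loc(σ) = ⟨σ|H|ψ⟩ / ⟨σ|ψ⟩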
...         return jnp.exp(logHpsi_σ - logpsi_σ)
...     return nk.jax.expect(log_pdf, expected_fun, pars, σ)[0]
>>>
>>> E, E_vjp_fun = nk.jax.vjp(expect, pars, σ)
>>> grad = E_vjp_fun(jnp.ones_like(E))[0]
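
The resulting gradient is a pytree with the same structure as pars, as with jax.vjp:

>>> jax.tree_util.tree_structure(grad) == jax.tree_util.tree_structure(pars)
True
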
Parameters:
  • log_pdf (Callable[[Any, Array], Array]) – The log-pdf of the distribution from which the samples were drawn. It should return real values and have the signature log_pdf(pars, σ) -> jnp.ndarray.

  • expected_fun (Callable[[Any, Array], Array]) – The function to compute the expectation value of. This should have a signature expected_fun(pars, σ, *expected_fun_args) -> jnp.ndarray.

  • pars (Any) – The parameters of the model.

  • σ (Array) – The samples to compute the expectation value over.

  • expected_fun_args – Additional arguments to pass to the expected_fun function. These are differentiated through; to avoid differentiating them, capture them as constants inside expected_fun.

  • n_chains (int | None) – The number of chains to use in the computation. If None, the number of chains is inferred from the shape of the input.

  • chunk_size (int | None) – The size of the chunks to use in the computation. If None, no chunking is used.

  • in_axes (tuple[int | None, ...] | None) – The axes along which to perform the chunking. If None, only the samples are chunked; otherwise this must be the sharding declaration of the samples and of the additional arguments to expected_fun (it must have length equal to the number of expected_fun_args + 2).

Return type:

tuple[Array, Stats]

Returns:

A tuple where the first element is the scalar expectation value, and the second element is a netket.stats.Stats object containing the statistics (including the mean and its error) of the estimate.
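
For example, assuming log_pdf and expected_fun from the example above are available at top level, one could inspect the statistics of the estimate (and optionally bound memory usage with chunk_size):

>>> value, stats = nk.jax.expect(log_pdf, expected_fun, pars, σ, chunk_size=4096)
>>> stats.mean, stats.error_of_mean, stats.variance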