netket.sampler.Sampler
- class netket.sampler.Sampler
Bases: Pytree
Abstract base class for all samplers.
It contains the fields that all of them should possess, defining the common API. Note that fields marked with pytree_node=False are treated as static arguments when jitting.
Subclasses should be NetKet dataclasses and should define the _init_state, _reset and _sample_chain methods, which only accept positional arguments. See each method's definition for its signature.
These methods are distinct from the public API entry points (the same names without the leading underscore). The split lets samplers share pre-processing code and simplifies the definition of a new sampler.
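The split between public entry points and underscore-prefixed hooks can be sketched with plain Python dataclasses. Everything below (`ToySampler`, `CoinSampler`, the dictionary used as state) is hypothetical and only mimics the pattern; it is not NetKet's implementation.

```python
from dataclasses import dataclass
import random


@dataclass(frozen=True)
class ToySampler:
    """Hypothetical base class: public methods hold the shared pre-processing."""
    n_chains: int = 4

    def init_state(self, seed=None):
        # Shared pre-processing: pick a seed if none was given.
        if seed is None:
            seed = random.randrange(2**31)
        return self._init_state(seed)

    def reset(self, state=None):
        # Shared pre-processing: build a state on demand.
        if state is None:
            state = self.init_state()
        return self._reset(state)

    def sample(self, state=None, chain_length=1):
        if state is None:
            state = self.reset()
        return self._sample_chain(state, chain_length)

    # Hooks that subclasses define; they take positional arguments only.
    def _init_state(self, seed):
        raise NotImplementedError

    def _reset(self, state):
        raise NotImplementedError

    def _sample_chain(self, state, chain_length):
        raise NotImplementedError


@dataclass(frozen=True)
class CoinSampler(ToySampler):
    """Hypothetical subclass that draws independent +1/-1 values."""

    def _init_state(self, seed):
        return {"rng": random.Random(seed), "steps": 0}

    def _reset(self, state):
        state["steps"] = 0
        return state

    def _sample_chain(self, state, chain_length):
        rng = state["rng"]
        samples = [[rng.choice((-1, 1)) for _ in range(chain_length)]
                   for _ in range(self.n_chains)]
        state["steps"] += chain_length
        return samples, state


sampler = CoinSampler(n_chains=3)
state = sampler.init_state(seed=0)
state = sampler.reset(state)
samples, state = sampler.sample(state, chain_length=5)
print(len(samples), len(samples[0]), state["steps"])  # prints: 3 5 5
```

A new toy sampler only has to supply the three hooks; the defaulting logic for seeds and states lives once in the base class.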
- Inheritance
- __init__(hilbert, *, machine_pow=2, dtype=<class 'float'>)
Construct a Monte Carlo sampler.
- Parameters:
- Attributes
- is_exact
Returns True if the sampler is exact.
The sampler is exact if all the samples are exactly distributed according to the chosen power of the variational state, and there is no correlation among them.
- n_batches
The batch size of the configuration $\sigma$ used by this sampler on this jax process.
This is used to determine the shape of the batches generated in a single process. This is needed because, when using JAX sharding, the full shape must be declared on every jax process; therefore this returns n_chains. Usage of this flag is required to support JAX sharding.
Samplers may override this to have a larger batch size, for example to propagate multiple replicas (in the case of parallel tempering).
- n_chains
The total number of independent chains.
This is at least equal to the total number of jax devices that are used to distribute the calculation.
- n_chains_per_rank
The total number of independent chains per jax device.
If you are not distributing the calculation among jax devices, this is equal to n_chains. In general this is equal to
    import jax
    sampler.n_chains // jax.device_count()
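As a quick arithmetic check of the relation above, the following sketch uses plain integers in place of a real sampler and of `jax.device_count()`. The helper name `chains_per_device` is hypothetical.

```python
def chains_per_device(n_chains: int, n_devices: int) -> int:
    """Integer division used to split chains evenly across devices."""
    if n_chains < n_devices:
        # n_chains is documented to be at least the number of devices.
        raise ValueError("need at least one chain per device")
    return n_chains // n_devices


single = chains_per_device(16, 1)  # one device: equals n_chains
split = chains_per_device(16, 4)   # four devices share the chains
print(single, split)  # prints: 16 4
```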
- hilbert: AbstractHilbert
The Hilbert space to sample.
- Methods
- init_state(machine, parameters, seed=None)
Creates the structure holding the state of the sampler.
If you want reproducible samples, you should specify seed, otherwise the state will be initialised randomly.
If running across several JAX processes, all `sampler_state`s are guaranteed to be in a different (but deterministic) state. This is achieved by first reducing (summing) the seed provided to every JAX process, then generating `n_process` seeds starting from the reduced one; every process is then initialized with one of those seeds.
The resulting state is guaranteed to be a frozen Python dataclass (in particular, a Flax dataclass), and it can be serialized using Flax serialization methods.
- Parameters:
machine (Union[Module, Callable[[Any, Union[ndarray, Array]], Union[ndarray, Array]]]) – A Flax module or callable with the forward pass of the log-pdf. If it is a callable, it should have the signature f(parameters, σ) -> jax.Array.
parameters (Any) – The PyTree of parameters of the model.
seed (Union[int, Any, None]) – An optional seed or jax PRNGKey. If not specified, a random seed will be used.
- Return type:
- Returns:
The structure holding the state of the sampler. In general you should not expect it to be in a valid state, and should reset it before use.
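The seed handling described for init_state (sum the seeds from all processes, then derive one seed per process from the reduced value) can be illustrated without JAX. The sketch below uses Python's `random` module as a stand-in for the PRNG; the function name `derive_process_seeds` is hypothetical, and the actual key derivation in NetKet may differ.

```python
import random


def derive_process_seeds(seeds_per_process):
    """Reduce per-process seeds by summation, then derive one seed each."""
    n_process = len(seeds_per_process)
    reduced = sum(seeds_per_process)   # identical on every process
    rng = random.Random(reduced)       # deterministic stream
    # Sample without replacement so no two processes share a seed.
    return rng.sample(range(2**31), n_process)


a = derive_process_seeds([7, 7, 7, 7])
b = derive_process_seeds([7, 7, 7, 7])
print(a == b, len(set(a)) == len(a))  # prints: True True
```

Because every process computes the same reduced value, each one can derive the full list independently and pick its own entry, which gives distinct but reproducible states.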
- log_pdf(model)
Returns a closure with the log-pdf function encoded by this sampler.
- Parameters:
model (Callable | Module) – A Flax module or callable with the forward pass of the log-pdf. If it is a callable, it should have the signature f(parameters, σ) -> jax.Array.
- Return type:
- Returns:
The log-probability density function.
Note
The result is returned as a HashablePartial so that the closure does not trigger recompilation.
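To see why a hashable closure helps, the sketch below builds a small stand-in for such an object in plain Python. The class `ToyHashablePartial` and the functions `toy_model` and `toy_log_pdf` are hypothetical. They assume the log-pdf is `machine_pow` times the real part of the model output, which is consistent with the "chosen power of the variational state" wording but is not taken from NetKet's source.

```python
class ToyHashablePartial:
    """Partial application that hashes and compares by function and arguments."""

    def __init__(self, func, *args):
        self.func = func
        self.args = args

    def __call__(self, *rest):
        return self.func(*self.args, *rest)

    def __eq__(self, other):
        return (isinstance(other, ToyHashablePartial)
                and self.func == other.func
                and self.args == other.args)

    def __hash__(self):
        return hash((self.func, self.args))


def toy_model(parameters, sigma):
    # Toy complex log-amplitude: weighted sum of the configuration.
    return complex(sum(w * s for w, s in zip(parameters, sigma)), 0.5)


def toy_log_pdf(machine_pow, model, parameters, sigma):
    # Assumed form of the log-pdf: power times the real part.
    return machine_pow * model(parameters, sigma).real


f = ToyHashablePartial(toy_log_pdf, 2, toy_model)
g = ToyHashablePartial(toy_log_pdf, 2, toy_model)

# A cache keyed on the closure sees f and g as the same entry.
cache = {f: "compiled once"}
print(g in cache, f((0.5, -1.0), (1, -1)))  # prints: True 3.0
```

A plain `lambda` or nested function would be a fresh object on every call, so a cache keyed on it would miss each time.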
- replace(**kwargs)
Replace the values of the fields of the object with the values of the keyword arguments. If the object is a dataclass, dataclasses.replace will be used. Otherwise, a new object will be created with the same type as the original object.
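Since the description refers to `dataclasses.replace`, here is how that standard-library function behaves on a frozen dataclass. `ToySamplerConfig` is a hypothetical stand-in, not a NetKet class.

```python
import dataclasses


@dataclasses.dataclass(frozen=True)
class ToySamplerConfig:
    n_chains: int = 16
    machine_pow: int = 2


original = ToySamplerConfig()
updated = dataclasses.replace(original, n_chains=32)

# The original is untouched; a new object carries the changed field.
print(original.n_chains, updated.n_chains, updated.machine_pow)  # prints: 16 32 2
```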
- reset(machine, parameters, state=None)
Resets the state of the sampler. To be used every time the parameters are changed.
- Parameters:
machine (Union[Module, Callable[[Any, Union[ndarray, Array]], Union[ndarray, Array]]]) – A Flax module or callable with the forward pass of the log-pdf. If it is a callable, it should have the signature f(parameters, σ) -> jax.Array.
parameters (Any) – The PyTree of parameters of the model.
state (SamplerState | None) – The current state of the sampler. If not specified, it will be constructed by calling sampler.init_state(machine, parameters) with a random seed.
- Return type:
- Returns:
A valid sampler state.
- sample(machine, parameters, *, state=None, chain_length=1, return_log_probabilities=False)
Samples chain_length batches of samples along the chains.
- Parameters:
machine (Callable | Module) – A Flax module or callable with the forward pass of the log-pdf. If it is a callable, it should have the signature f(parameters, σ) -> jax.Array.
parameters (Any) – The PyTree of parameters of the model.
state (SamplerState | None) – The current state of the sampler. If not specified, then initialize and reset it.
chain_length (int) – The length of the chains (default = 1).
return_log_probabilities (bool) – If True, the log-probabilities are also returned, which is sometimes useful to avoid re-evaluating the log-pdf when doing importance sampling. Defaults to False.
- Return type:
tuple
[Array
,SamplerState
] |tuple
[tuple
[Array
,Array
],SamplerState
]- Returns:
Returns a tuple of βsamplesβ and βstateβ. If return_log_probabilities is False, the samples are just the 3-rank array of samples. If return_log_probabilities is True, the samples are a tuple of the 3-rank array of samples and the 2-rank array of un-normalized log-probabilities corresponding to each sample.
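The two return conventions can be illustrated with a toy single-spin-flip Metropolis loop in plain Python. This is only a sketch: `toy_log_psi` and `toy_sample` are hypothetical, nested lists stand in for arrays, and the axis order (chain, step, site) is an assumption rather than something stated above.

```python
import math
import random


def toy_log_psi(sigma):
    # Hypothetical log-amplitude favouring aligned neighbouring spins.
    return 0.3 * sum(a * b for a, b in zip(sigma, sigma[1:]))


def toy_sample(rng, configs, chain_length=1, return_log_probabilities=False,
               machine_pow=2):
    """Return (samples, state) or ((samples, log_probs), state)."""
    n_chains = len(configs)
    samples = [[] for _ in range(n_chains)]    # indexed [chain][step][site]
    log_probs = [[] for _ in range(n_chains)]  # indexed [chain][step]
    for c in range(n_chains):
        sigma = list(configs[c])
        log_p = machine_pow * toy_log_psi(sigma)
        for _ in range(chain_length):
            site = rng.randrange(len(sigma))
            proposal = sigma.copy()
            proposal[site] = -proposal[site]
            new_log_p = machine_pow * toy_log_psi(proposal)
            # Accept with probability min(1, p_new / p_old).
            if rng.random() < math.exp(min(0.0, new_log_p - log_p)):
                sigma, log_p = proposal, new_log_p
            samples[c].append(list(sigma))
            log_probs[c].append(log_p)  # un-normalized log-probability
        configs[c] = sigma
    state = (rng, configs)
    if return_log_probabilities:
        return (samples, log_probs), state
    return samples, state


rng = random.Random(0)
configs = [[1, -1, 1, -1] for _ in range(3)]
(samples, log_probs), state = toy_sample(
    rng, configs, chain_length=5, return_log_probabilities=True
)
shape3 = (len(samples), len(samples[0]), len(samples[0][0]))
shape2 = (len(log_probs), len(log_probs[0]))
print(shape3, shape2)  # prints: (3, 5, 4) (3, 5)
```

Returning the log-probabilities alongside the samples avoids a second pass over the model when they are needed later, for example for importance weights.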
- samples(machine, parameters, *, state=None, chain_length=1)
Returns a generator sampling chain_length batches of samples along the chains.
- Parameters:
machine (Union[Module, Callable[[Any, Union[ndarray, Array]], Union[ndarray, Array]]]) – A Flax module or callable with the forward pass of the log-pdf. If it is a callable, it should have the signature f(parameters, σ) -> jax.Array.
parameters (Any) – The PyTree of parameters of the model.
state (SamplerState | None) – The current state of the sampler. If not specified, then initialize and reset it.
chain_length (int) – The length of the chains (default = 1).
- Return type: