netket.experimental.driver.VMC_SR#
- class netket.experimental.driver.VMC_SR[source]#
Bases: AbstractVariationalDriver
Energy minimization using Variational Monte Carlo (VMC) and Stochastic Reconfiguration/Natural Gradient Descent. This driver is mathematically equivalent to the standard netket.driver.VMC with the preconditioner netket.optimizer.SR(solver=netket.optimizer.solvers.cholesky), but it can easily switch between the standard and the kernel/minSR formulation of Natural Gradient Descent.
The standard formulation computes the updates as:
\[\delta \theta = \tau (X^TX + \lambda \mathbb{I}_{N_p})^{-1} X^T E^{loc},\]
where \(X \in \mathbb{R}^{N_s \times N_p}\) is the Jacobian of the log-wavefunction, with \(N_p\) the number of parameters and \(N_s\) the number of samples, and \(E^{loc}\) is the vector of centered local energies.
The kernel/minSR formulation computes the updates as:
\[\delta \theta = \tau X^T(XX^T + \lambda \mathbb{I}_{2N_s})^{-1} E^{loc}.\]
The regularization parameter \(\lambda\) is the diag_shift parameter of the driver, which can be a scalar or a schedule. The updates are then applied to the parameters using the optimizer, which in general should be optax.sgd.
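A minimal usage sketch follows; the Hilbert space, Hamiltonian and model below are placeholders chosen only for illustration, not part of this driver's API.

```python
import netket as nk
import netket.experimental as nkx
import optax

# Placeholder problem: transverse-field Ising chain with an RBM ansatz.
hi = nk.hilbert.Spin(s=1 / 2, N=16)
H = nk.operator.Ising(hilbert=hi, graph=nk.graph.Chain(16), h=1.0)
vstate = nk.vqs.MCState(nk.sampler.MetropolisLocal(hi), nk.models.RBM(), n_samples=1024)

driver = nkx.driver.VMC_SR(
    H,
    optax.sgd(learning_rate=0.01),  # plain SGD: the preconditioning is done by SR/NGD
    diag_shift=1e-4,
    variational_state=vstate,
)
driver.run(n_iter=100, out="example_run")
```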
Matrix Inversion#
The matrix inversion in both formulations is performed using a linear solver, which can be specified by the user through the
linear_solver_fn
argument. This must be a function with the following signature:
linear_solver_fn(A: Matrix, b: vector) -> tuple[jax.Array, dict]
where the returned array is the solution and the dictionary may contain additional information about the solver, or be None. The default solver is based on the Cholesky decomposition, cholesky(); a minimal sketch of a compatible custom solver is shown below.
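As an illustration (not the built-in solver), any function matching that signature works. The example below is a hypothetical dense solver built on jax.scipy.linalg.solve; the name my_dense_solver and the returned diagnostic are assumptions for this sketch.

```python
import jax
import jax.numpy as jnp


def my_dense_solver(A: jax.Array, b: jax.Array) -> tuple[jax.Array, dict]:
    # Solve A x = b with a dense LU-based solve.
    x = jax.scipy.linalg.solve(A, b)
    # The second element of the tuple may carry extra information (or be None).
    info = {"residual_norm": jnp.linalg.norm(A @ x - b)}
    return x, info
```

Such a function could then be passed as linear_solver_fn=my_dense_solver when constructing the driver.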
Any other solver from JAX, the NetKet solvers, or a custom-written one can also be used.
Natural Gradient Descent#
Stochastic Reconfiguration is equivalent to the Natural Gradient Descent method introduced by Amari (1998) in the context of neural-network training, assuming that the natural metric of the space of wave-functions is the Fubini-Study metric. This was first studied by Stokes et al. (2019), who called it Quantum Natural Gradient Descent.
While Stochastic Reconfiguration has been heavily studied in the context of VMC, there is a vast literature in the machine-learning community on the use of NGD and on carefully tuning the diagonal shift and the learning rate.
A very good introduction to the mathematics of Information Geometry and NGD can be found in Bai et al., and it is studied further in Shrestha et al. (2022). From a physicist's point of view, a good discussion of the choice of the metric function (QGT vs Fisher matrix) is found in Stokes et al. (2022) (Section 4 in particular). For a comprehensive review of the method, we suggest the review by Martens (2014).
Momentum / SPRING#
When momentum is used, this driver implements the SPRING optimizer of Goldshlager et al. (2024), which accumulates previous updates for a better approximation of the exact SR direction with no significant performance penalty.
The momentum \(\mu\) is a number in [0, 1] that specifies the damping factor of the previous updates and works somewhat similarly to the beta parameter of ADAM. The difference is that rather than simply adding the damped previous update to the new update, SPRING uses the damped previous update to fill in the components of the SR direction that are not sampled by the current batch of walkers, resulting in a more accurate and less noisy estimate. Since SPRING only uses the previous update to fill in directions that are orthogonal to the current one, the maximum amplification of the step size in SPRING is \(A(\mu) = 1/\sqrt{1-\mu^2}\) rather than \(1/(1-\mu)\).
Thus the amplification is at most a factor of \(A(0.9)=2.3\) or \(A(0.99)=7.1\). Values around \(\mu = 0.8\) empirically work well.
Some progress has been made on theoretically analyzing this parameter. In particular, Section 3 of Epperly et al. demonstrates (albeit in a significantly simplified linear least-squares setting) that SPRING can be interpreted as iteratively estimating a regularized SR direction, with the amount of regularization proportional to \(1-\mu\). Additional insights into the behavior of some SPRING-like algorithms, still in the linear least-squares setting, are presented in Goldshlager et al. (2025).
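For example, reusing the hypothetical H and vstate from the earlier sketch, SPRING is enabled simply by passing momentum:

```python
import netket.experimental as nkx
import optax

driver = nkx.driver.VMC_SR(
    H,                           # hypothetical Hamiltonian from the earlier sketch
    optax.sgd(learning_rate=0.02),
    diag_shift=1e-4,
    momentum=0.8,                # SPRING damping factor; 0.8 is the empirically suggested value
    variational_state=vstate,    # hypothetical variational state from the earlier sketch
)
```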
Implementation details#
The kernel-trick/NTK-based implementation can run either with a direct calculation of the Jacobian (on_the_fly=False) or with a lazy evaluation of the NTK (on_the_fly=True). The latter is more computationally efficient for networks that reuse the parameters many times in every forward pass (convolutions, attention layers, but not dense layers…) and generally uses less memory.
However, the on-the-fly implementation relies on some JAX compiler behaviour, so it might at times have worse performance. We suggest you check on your specific model; a sketch of how to construct both variants is shown below. For a more detailed explanation of the on-the-fly implementation of the NTK, we refer to Novak et al. (2022); NetKet uses the layer-wise Jacobian contraction method (Sec. 3.2 of that manuscript).
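The following sketch constructs both variants so they can be timed on a given model; H and vstate are again the placeholders from the earlier example.

```python
import netket.experimental as nkx
import optax

# Kernel/minSR formulation with an explicitly materialized Jacobian...
driver_jac = nkx.driver.VMC_SR(
    H, optax.sgd(0.01), diag_shift=1e-4,
    variational_state=vstate, use_ntk=True, on_the_fly=False,
)
# ...and the same formulation with lazy (on-the-fly) evaluation of the NTK.
driver_lazy = nkx.driver.VMC_SR(
    H, optax.sgd(0.01), diag_shift=1e-4,
    variational_state=vstate, use_ntk=True, on_the_fly=True,
)
```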
The default choice is the on_the_fly=True mode.
References
Stochastic Reconfiguration was originally introduced in the QMC field by Sorella. The method was later shown to be equivalent to the Natural Gradient Descent introduced by Amari, for the Fubini-Study metric.
The kernel trick that makes NGD/SR feasible in the large-parameter-count limit was originally introduced to the field of NQS by Chen & Heyl under the name of minSR. Rende et al. proposed a simpler derivation in terms of the kernel trick.
It is interesting to note that these tricks were first mentioned by Ren & Goldfarb in the ML community.
When using momentum, you should cite Goldshlager et al. (2024).
- Inheritance
- __init__(hamiltonian, optimizer, *, diag_shift, proj_reg=None, momentum=None, linear_solver_fn=<function cholesky>, variational_state=None, chunk_size_bwd=None, mode=None, use_ntk=None, on_the_fly=None)[source]#
Initialize the driver with the given arguments.
Warning
The optimizer should be an instance of optax.sgd. Other optimizers, while they might work, will not make mathematical sense in the context of the SR/NGD optimization.
- Parameters:
hamiltonian (AbstractOperator) – The Hamiltonian whose ground state is to be found.
optimizer (Any) – The optimizer to use for the parameter updates. To perform proper SR/NGD optimization this should be an instance of optax.sgd, but it can be any other optimizer if you are brave.
variational_state (MCState) – The variational state to optimize.
diag_shift (scalar or schedule) – The diagonal regularization parameter \(\lambda\) for the QGT/NTK. Can be a scalar or a callable mapping the step to a scalar.
proj_reg (scalar, schedule, or None) – The regularization parameter for the projection of the updates. (This is usually not very important and can be left to None.)
momentum (scalar, schedule, or None) – (SPRING, disabled by default; see above for details.) A number in [0, 1] specifying the damping factor of the previous updates; it works somewhat similarly to the beta parameter of ADAM. The maximum amplification of the step size in SPRING is \(A(\mu)=1/\sqrt{1-\mu^2}\), so the amplification is at most a factor of \(A(0.9)=2.3\) or \(A(0.99)=7.1\). Values around momentum = 0.8 empirically work well. (Defaults to None.)
linear_solver_fn (Callable[[Array, Array], Array]) – The linear solver function to use for the NGD solver.
mode (JacobianMode | None) – The mode used to compute the Jacobian of the variational state. Can be 'real' or 'complex'. 'real' can be used for real-valued wavefunctions with a sign, truncating the arbitrary phase of the wavefunction; this leads to lower computational cost.
on_the_fly (bool | None) – Whether to compute the QGT or NTK using lazy evaluation methods. This usually requires less memory. (Defaults to None, which automatically chooses the potentially best method.)
chunk_size_bwd (int | None) – The number of rows of the NTK or of the Jacobian evaluated in a single sweep.
use_ntk (bool | None) – Whether to compute the updates using the Neural Tangent Kernel (NTK) instead of the Quantum Geometric Tensor (QGT), i.e. switching between SR and minSR. (Defaults to None, which automatically chooses the best method.)
- Attributes
- chunk_size_bwd#
Chunk size for backward-mode differentiation. This reduces memory pressure at a potential cost of higher computation time.
If computing the jacobian, the jacobian is computed in blocks of chunk_size_bwd rows. If computing the NTK lazily, this is the number of rows of NTK evaluated in a single sweep. The chunk size does not affect the result, up to numerical precision.
- diag_shift#
The diagonal shift \(\lambda\) in the curvature matrix.
This can be a scalar or a schedule. If it is a schedule, it should be a function that takes the current step as input and returns the value of the shift.
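For instance, either an optax schedule or a plain Python function of the step can be passed; the values below are purely illustrative.

```python
import optax

# Linearly decay the diagonal shift from 1e-2 to 1e-4 over 300 steps.
diag_shift_schedule = optax.linear_schedule(
    init_value=1e-2, end_value=1e-4, transition_steps=300
)

# Equivalent hand-written schedule: any callable mapping step -> scalar works.
def my_diag_shift(step):
    return max(1e-4, 1e-2 * 0.99**step)
```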
- mode#
The mode used to compute the Jacobian of the variational state. Can be 'real', 'complex', or 'onthefly'.
'real' mode truncates the imaginary part of the wavefunction, which is useful for real-valued wavefunctions with a sign.
'complex' is the general implementation that always works.
'onthefly' uses a lazy implementation of the neural tangent kernel and does not compute the Jacobian.
This internally uses netket.jax.jacobian(); see that function for more complete documentation.
- momentum#
Flag specifying whether to use momentum in the optimisation.
If True, the optimizer accumulates previous updates following the SPRING optimizer from G. Goldshlager, N. Abrahamsen and L. Lin, for a better approximation of the exact SR with no significant performance penalty.
- on_the_fly#
Whether the QGT or NTK is computed using lazy (on-the-fly) evaluation instead of building the full Jacobian.
- optimizer#
The optimizer used to update the parameters at every iteration.
- proj_reg#
The regularization parameter for the projection of the updates. This is usually not very important and can be left to None.
- state#
Returns the machine that is optimized by this driver.
- step_count#
Returns a monotonic integer labelling all the steps performed by this driver. This can be used, for example, to identify the line in a log file.
- use_ntk#
Whether to use the Neural Tangent Kernel (NTK) instead of the Quantum Geometric Tensor (QGT) to compute the update.
- Methods
- advance(steps=1)[source]#
Performs steps optimization steps.
- Parameters:
steps (int) – (Default=1) number of steps.
- estimate(observables)[source]#
Return MCMC statistics for the expectation value of observables in the current state of the driver.
- Parameters:
observables – A pytree of operators for which statistics should be computed.
- Returns:
A pytree of the same structure as the input, containing MCMC statistics for the corresponding operators as leaves.
- iter(n_steps, step=1)[source]#
Returns a generator which advances the VMC optimization, yielding after every step steps.
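A sketch of the generator interface, assuming driver and H were constructed as in the earlier examples:

```python
# Advance the optimization in chunks of 10 steps, inspecting the energy in between.
for step in driver.iter(n_steps=100, step=10):
    print(step, driver.state.expect(H))
```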
- reset()[source]#
Resets the driver.
Subclasses should make sure to call
super().reset()
to ensure that the step count is set to 0.
- run(n_iter, out=(), obs=None, step_size=1, show_progress=True, save_params_every=50, write_every=50, callback=<function AbstractVariationalDriver.<lambda>>, timeit=False)[source]#
Runs this variational driver, updating the weights of the network stored in this driver for n_iter steps and dumping values of the observables obs in the output logger.
It is possible to control more specifically what quantities are logged, when to stop the optimisation, or to execute arbitrary code at every step by specifying one or more callbacks, which are passed as a list of functions to the keyword argument callback.
Callbacks are functions that follow this signature:
def callback(step, log_data, driver) -> bool:
    ...
    return True/False
If a callback returns True, the optimisation continues, otherwise it is stopped. The log_data is a dictionary that can be modified in-place to change what is logged at every step. For example, this can be used to log additional quantities such as the acceptance rate of a sampler.
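For example, a hypothetical callback (the name my_callback, the "Energy" key and the thresholds are assumptions for this sketch) that logs an extra quantity and stops the run once the energy error is small:

```python
def my_callback(step, log_data, driver) -> bool:
    # Add an extra quantity to the logged data (modified in-place).
    log_data["n_samples"] = driver.state.n_samples
    # Keep running while the energy error is still large or few steps were done.
    energy = log_data["Energy"]
    return step < 100 or energy.error_of_mean > 1e-3

driver.run(n_iter=1000, out="log", callback=my_callback)
```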
Loggers are specified as an iterable passed to the keyword argument out. If only a string is specified, this will by default create a
nk.logging.JsonLog
. To learn about the output format, check its documentation. The logger object is also returned at the end of this function so that you can inspect the results without reading the JSON output. When running on multiple JAX devices, the logging logic is executed on all nodes, but only root-rank loggers should write to files or perform expensive I/O operations.
Note
Before NetKet 3.15, loggers were automatically 'ignored' on non-root ranks. However, starting with NetKet 3.15 it is the responsibility of a logger to check whether it is executing on a non-root rank, and to 'do nothing' if that is the case.
The change was required to work correctly and efficiently with sharding. It will only affect users that were defining custom loggers themselves.
- Parameters:
n_iter (int) – the total number of iterations to be performed during this run.
out (AbstractLog | Iterable[AbstractLog] | str | None) – A logger object, or an iterable of loggers, to be used to store simulation log and data. If this argument is a string, it will be used as output prefix for the standard JSON logger.
obs (dict[str, AbstractObservable] | None) – An iterable containing all observables that should be computed.
step_size (int) – Every how many steps should observables be logged to disk (default=1).
callback (Callable[[int, dict, AbstractVariationalDriver], bool] | Iterable[Callable[[int, dict, AbstractVariationalDriver], bool]]) – Callable or list of callable callback functions to stop training given a condition.
show_progress (bool) – If true, displays a progress bar (default=True).
save_params_every (int) – Every how many steps the parameters of the network should be serialized to disk (ignored if a logger is provided).
write_every (int) – Every how many steps the JSON data should be flushed to disk (ignored if a logger is provided).
timeit (bool) – If True, provide timing information.