Part 3: Monte Carlo Sampling#
In this final tutorial, we will:
Implement Monte Carlo sampling for larger systems
Compute local energies using sparse operator connections
Estimate energies and gradients from samples
Build a complete VMC optimization loop
Explore advanced optimizers and future extensions
This tutorial builds on the concepts from Parts 1 and 2, extending to the Monte Carlo regime where full summation over the Hilbert space is no longer feasible.
Note
If you are executing this notebook on Colab, you will need to install NetKet:
# %pip install --quiet netket
1. Setup from Previous Tutorials#
Let's recreate the complete system from Parts 1 and 2:
# Imports used throughout this tutorial
from functools import partial

import numpy as np
import jax
import jax.numpy as jnp
import flax.linen as nn
import netket as nk
import matplotlib.pyplot as plt
from tqdm.auto import tqdm

# Define the system
L = 4
g = nk.graph.Hypercube(length=L, n_dim=2, pbc=True)
hi = nk.hilbert.Spin(s=1 / 2, N=g.n_nodes)
# Build the Hamiltonian
hamiltonian = nk.operator.LocalOperator(hi)
# Add transverse field terms
for site in g.nodes():
hamiltonian = hamiltonian - 1.0 * nk.operator.spin.sigmax(hi, site)
# Add Ising interaction terms
for i, j in g.edges():
hamiltonian = hamiltonian + nk.operator.spin.sigmaz(
hi, i
) @ nk.operator.spin.sigmaz(hi, j)
# Convert to JAX format
hamiltonian_jax = hamiltonian.to_pauli_strings().to_jax_operator()
# Compute exact ground state for comparison
from scipy.sparse.linalg import eigsh
e_gs, psi_gs = eigsh(hamiltonian.to_sparse(), k=1)
e_gs = e_gs[0]
psi_gs = psi_gs.reshape(-1)
print(f"Exact ground state energy: {e_gs:.6f}")
Exact ground state energy: -34.010598
2. Variational Models from Part 2#
Let's redefine our variational ansätze:
# Mean Field Ansatz
class MF(nn.Module):
@nn.compact
def __call__(self, x):
lam = self.param("lambda", nn.initializers.normal(), (1,), float)
p = nn.log_sigmoid(lam * x)
return 0.5 * jnp.sum(p, axis=-1)
# Jastrow Ansatz
class Jastrow(nn.Module):
@nn.compact
def __call__(self, x):
n_sites = x.shape[-1]
J = self.param("J", nn.initializers.normal(), (n_sites, n_sites), float)
dtype = jax.numpy.promote_types(J.dtype, x.dtype)
J = J.astype(dtype)
x = x.astype(dtype)
J_symm = J.T + J
return jnp.einsum("...i,ij,...j", x, J_symm, x)
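As a quick sanity check (a minimal sketch, reusing the Hilbert space hi defined above), both ansätze can be initialized and evaluated on a single configuration; each returns the log-amplitude \(\log\psi(\sigma)\) of that configuration:
# Initialize and evaluate both ansätze on one random configuration
sigma_test = hi.random_state(jax.random.key(0))
mf_model = MF()
mf_params = mf_model.init(jax.random.key(1), sigma_test)
print("MF log-amplitude:     ", mf_model.apply(mf_params, sigma_test))
jastrow_model = Jastrow()
jastrow_params = jastrow_model.init(jax.random.key(1), sigma_test)
print("Jastrow log-amplitude:", jastrow_model.apply(jastrow_params, sigma_test))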
3. Monte Carlo Sampling#
For larger problems, we cannot sum over the whole Hilbert space. Instead, we use Monte Carlo sampling to generate configurations according to \(|\psi(\sigma)|^2\).
3.1 Setting up the Sampler#
We use a Metropolis sampler that proposes new states by flipping individual spins:
sampler = nk.sampler.MetropolisSampler(
hi, # the hilbert space to be sampled
nk.sampler.rules.LocalRule(), # the transition rule
n_chains=20, # number of parallel chains
)
3.2 Generating Samples#
Samplers are used as follows:
Initialize the sampler state
Reset when changing parameters
Call sample to generate new configurations
# Example with Mean Field model
model = MF()
parameters = model.init(jax.random.key(0), np.ones((hi.size,)))
# Initialize sampler state
sampler_state = sampler.init_state(model, parameters, seed=1)
sampler_state = sampler.reset(model, parameters, sampler_state)
# Generate samples
samples, sampler_state = sampler.sample(
model, parameters, state=sampler_state, chain_length=100
)
print(f"Sample shape: {samples.shape}")
# Dimensions: (n_chains, chain_length, n_sites)
# Note: chains are sometimes referred to as walkers
Sample shape: (5, 100, 16)
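Since the models act on batches of configurations, it is often convenient to merge the chain and chain-length dimensions into a single batch dimension before evaluating the wavefunction. A minimal sketch, reusing the samples generated above:
# Flatten (n_chains, chain_length, n_sites) -> (n_chains * chain_length, n_sites)
samples_flat = samples.reshape(-1, hi.size)
log_amplitudes = model.apply(parameters, samples_flat)
print(samples_flat.shape, log_amplitudes.shape)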
4. Computing Local Energies#
We want to compute the energy as an expectation value over configurations sampled from \(|\psi(\sigma)|^2\):
\[\langle E \rangle = \frac{\langle\psi|\hat{H}|\psi\rangle}{\langle\psi|\psi\rangle} \approx \frac{1}{N_s}\sum_{i=1}^{N_s} E_\text{loc}(\sigma_i),\]
where \(\sigma_i\) are the samples and \(E_\text{loc}\) is the local energy:
\[E_\text{loc}(\sigma) = \sum_{\eta} \frac{\langle\sigma|\hat{H}|\eta\rangle\,\psi(\eta)}{\psi(\sigma)}.\]
4.1 Understanding Operator Connections#
The sum over \(\eta\) is only over configurations connected to \(\sigma\) by the Hamiltonian (i.e., where \(\langle\sigma|H|\eta\rangle \neq 0\)). NetKet's operators provide this efficiently:
# Example: get connections for a single configuration
sigma = hi.random_state(jax.random.key(1))
eta, H_sigmaeta = hamiltonian_jax.get_conn_padded(sigma)
print(f"Input configuration shape: {sigma.shape}")
print(f"Connected configurations shape: {eta.shape}")
print(f"Matrix elements shape: {H_sigmaeta.shape}")
# For this Hamiltonian, each configuration is connected to itself (the diagonal Ising terms)
# and to the 16 configurations obtained by flipping a single spin (the transverse-field terms),
# giving 1 + 16 = 17 connections in total
Input configuration shape: (16,)
Connected configurations shape: (17, 16)
Matrix elements shape: (17,)
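As a small illustrative check (a sketch based on the shapes printed above), we can count how many spins each connected configuration flips relative to \(\sigma\): the diagonal Ising terms leave \(\sigma\) unchanged, while each transverse-field term flips exactly one spin.
# Count in how many sites each connected configuration differs from sigma
n_flipped = jnp.sum(eta != sigma, axis=-1)
print(n_flipped)  # expected: one entry equal to 0 (diagonal term) and sixteen entries equal to 1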
This also works for batches of configurations:
sigma_batch = hi.random_state(jax.random.key(1), (4, 5))
eta_batch, H_batch = hamiltonian_jax.get_conn_padded(sigma_batch)
print(f"Batch input shape: {sigma_batch.shape}")
print(f"Batch connected configurations shape: {eta_batch.shape}")
print(f"Batch matrix elements shape: {H_batch.shape}")
Batch input shape: (4, 5, 16)
Batch connected configurations shape: (4, 5, 17, 16)
Batch matrix elements shape: (4, 5, 17)
5. Exercise: Computing Local Energies#
Implement a function to compute local energies using the connection information:
def compute_local_energies(model, parameters, hamiltonian_jax, sigma):
eta, H_sigmaeta = hamiltonian_jax.get_conn_padded(sigma)
logpsi_sigma = model.apply(parameters, sigma)
logpsi_eta = model.apply(parameters, eta)
logpsi_sigma = jnp.expand_dims(logpsi_sigma, -1)
res = jnp.sum(H_sigmaeta * jnp.exp(logpsi_eta - logpsi_sigma), axis=-1)
return res
Test your implementation:
# Uncomment after implementing compute_local_energies
# assert compute_local_energies(model, parameters, hamiltonian_jax, samples[0]).shape == samples.shape[1:-1]
# assert compute_local_energies(model, parameters, hamiltonian_jax, samples).shape == samples.shape[:-1]
# Check that it JIT compiles
# jax.jit(compute_local_energies, static_argnames='model')(model, parameters, hamiltonian_jax, sigma)
# print("compute_local_energies implementation is correct!")
6. Exercise: Estimating Energy from Samples#
Write a function that estimates the energy and its statistical error from samples. The error of the mean is given by
\[\epsilon = \sqrt{\frac{\mathrm{Var}\left[E_\text{loc}\right]}{N_s}},\]
where \(N_s\) is the number of samples (this simple estimate assumes uncorrelated samples; Markov-chain samples are in general somewhat autocorrelated).
@partial(jax.jit, static_argnames='model')
def estimate_energy(model, parameters, hamiltonian_jax, sigma):
E_loc = compute_local_energies(model, parameters, hamiltonian_jax, sigma)
E_average = jnp.mean(E_loc)
E_variance = jnp.var(E_loc)
E_error = jnp.sqrt(E_variance / E_loc.size)
return nk.stats.Stats(mean=E_average, error_of_mean=E_error, variance=E_variance)
Test the energy estimation:
# Uncomment after implementing estimate_energy
# energy_estimate = estimate_energy(model, parameters, hamiltonian_jax, samples)
# print("Energy estimate:", energy_estimate)
Let's verify our Monte Carlo estimate against the exact calculation by generating more samples:
# Uncomment after implementing functions
# samples_many, sampler_state = sampler.sample(model, parameters, state=sampler_state, chain_length=5000)
# Compare with full summation from Part 2
# def compute_energy_exact(model, parameters, hamiltonian_sparse):
# all_configurations = hi.all_states()
# logpsi = model.apply(parameters, all_configurations)
# psi = jnp.exp(logpsi)
# psi = psi / jnp.linalg.norm(psi)
# return psi.conj().T @ (hamiltonian_sparse @ psi)
# hamiltonian_sparse = hamiltonian.to_sparse()
# exact_energy = compute_energy_exact(model, parameters, hamiltonian_sparse)
# mc_estimate = estimate_energy(model, parameters, hamiltonian_jax, samples_many)
# print(f"Exact calculation: {exact_energy:.6f}")
# print(f"MC estimate: {mc_estimate}")
7. Gradient Estimation with Monte Carlo#
The gradient of the energy with respect to the variational parameters can be estimated from the same samples:
\[\nabla_k \langle E \rangle \approx \frac{1}{N_s}\sum_{i=1}^{N_s} \left(E_\text{loc}(\sigma_i) - \bar{E}\right)\,\nabla_k \log\psi(\sigma_i),\]
where \(\bar{E}\) is the sample mean of the local energies.
We can compute this efficiently using JAX's vector-Jacobian product (VJP).
7.1 Understanding the Jacobian#
Think of \(\nabla_k \log\psi(\sigma_i)\) as the JACOBIAN of the function \(\log\psi_\sigma : \mathbb{R}^{N_\text{pars}} \rightarrow \mathbb{R}^{N_\text{samples}}\):
# Example with Jastrow model
model_jastrow = Jastrow()
parameters_jastrow = model_jastrow.init(
jax.random.key(0), hi.random_state(jax.random.key(0))
)
# Reshape samples to a vector
sigma_vector = samples.reshape(-1, hi.size)
# Define the function to differentiate
logpsi_sigma_fun = lambda pars: model_jastrow.apply(pars, sigma_vector)
print(f"Input parameters shape: {jax.tree.map(lambda x: x.shape, parameters_jastrow)}")
print(f"Output shape: {logpsi_sigma_fun(parameters_jastrow).shape}")
# We can compute the Jacobian
jacobian = jax.jacrev(logpsi_sigma_fun)(parameters_jastrow)
print(f"Jacobian shape: {jax.tree.map(lambda x: x.shape, jacobian)}")
Input parameters shape: {'params': {'J': (16, 16)}}
Output shape: (500,)
Jacobian shape: {'params': {'J': (500, 16, 16)}}
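As an optional sanity check, contracting this explicit Jacobian with a vector should give the same result as jax.vjp, which computes the product without ever materializing the full Jacobian. A minimal sketch, where the uniform weight vector v is just a stand-in for the centred local energies used below:
# Stand-in weight vector (plays the role of (E_loc - E_mean)/N_samples later on)
v = jnp.ones_like(logpsi_sigma_fun(parameters_jastrow)) / sigma_vector.shape[0]
# Contract the explicit Jacobian with v ...
grad_dense = jax.tree.map(lambda j: jnp.tensordot(v, j, axes=1), jacobian)
# ... and compare against the vector-Jacobian product
_, vjpfun = jax.vjp(logpsi_sigma_fun, parameters_jastrow)
grad_vjp = vjpfun(v)[0]
print(jax.tree.map(lambda a, b: jnp.allclose(a, b), grad_dense, grad_vjp))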
8. Exercise: Energy and Gradient Estimation#
Implement a function that computes both energy and gradient estimates using VJP:
@partial(jax.jit, static_argnames='model')
def estimate_energy_and_gradient(model, parameters, hamiltonian_jax, sigma):
# reshape the samples to a vector of samples with no extra batch dimensions
sigma = sigma.reshape(-1, sigma.shape[-1])
E_loc = compute_local_energies(model, parameters, hamiltonian_jax, sigma)
# compute the energy as well
E_average = jnp.mean(E_loc)
E_variance = jnp.var(E_loc)
E_error = jnp.sqrt(E_variance/E_loc.size)
E = nk.stats.Stats(mean=E_average, error_of_mean=E_error, variance=E_variance)
# compute the gradient using VJP
logpsi_sigma_fun = lambda pars : model.apply(pars, sigma)
_, vjpfun = jax.vjp(logpsi_sigma_fun, parameters)
E_grad = vjpfun((E_loc - E_average)/E_loc.size)[0]
return E, E_grad
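Before plugging this into an optimization loop, you can sanity-check it on the Jastrow model and the Monte Carlo samples generated earlier (a sketch):
# Uncomment after implementing estimate_energy_and_gradient
# E, E_grad = estimate_energy_and_gradient(model_jastrow, parameters_jastrow, hamiltonian_jax, samples)
# print("Energy estimate:", E)
# print("Gradient shapes:", jax.tree.map(lambda x: x.shape, E_grad))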
9. Exercise: Complete VMC Optimization Loop#
Now implement a complete VMC optimization using Monte Carlo sampling:
# Settings
model = Jastrow() # Try both MF() and Jastrow()
sampler = nk.sampler.MetropolisSampler(
hi,
nk.sampler.rules.LocalRule(),
n_chains=20
)
n_iters = 300
chain_length = 1000 // sampler.n_chains
# Initialize
parameters = model.init(jax.random.key(0), np.ones((hi.size,)))
sampler_state = sampler.init_state(model, parameters, seed=1)
# Logging
logger = nk.logging.RuntimeLog()
for i in tqdm(range(n_iters)):
# sample
sampler_state = sampler.reset(model, parameters, state=sampler_state)
samples, sampler_state = sampler.sample(model, parameters, state=sampler_state, chain_length=chain_length)
# compute energy and gradient
E, E_grad = estimate_energy_and_gradient(model, parameters, hamiltonian_jax, samples)
# update parameters
parameters = jax.tree.map(lambda x,y: x-0.005*y, parameters, E_grad)
# log energy
logger(step=i, item={'Energy':E})
100%|██████████| 300/300 [00:02<00:00, 131.90it/s]
Plot the optimization results:
# Uncomment after running optimization
# plt.figure(figsize=(12, 5))
# plt.subplot(1, 2, 1)
# plt.plot(logger.data['Energy']['iters'], logger.data['Energy']['Mean'])
# plt.axhline(y=e_gs, color='r', linestyle='--', label='Exact ground state')
# plt.xlabel('Iteration')
# plt.ylabel('Energy')
# plt.title('VMC Energy vs Iteration')
# plt.legend()
# plt.subplot(1, 2, 2)
# plt.semilogy(logger.data['Energy']['iters'], np.abs(logger.data['Energy']['Mean'] - e_gs))
# plt.xlabel('Iteration')
# plt.ylabel('|Energy - Exact|')
# plt.title('Error vs Iteration (log scale)')
# plt.tight_layout()
10. Advanced Topics and Extensions#
10.1 Better Optimizers with Optax#
You can use more sophisticated optimizers from the optax library:
import optax
# Example optimization loop with Adam
def optimize_with_adam(n_iters=100, chain_length=50):
    # Define optimizer
    optimizer = optax.adam(learning_rate=0.01)
    # Initialize model, sampler and optimizer state
    model = Jastrow()
    parameters = model.init(jax.random.key(0), np.ones((hi.size,)))
    sampler = nk.sampler.MetropolisSampler(hi, nk.sampler.rules.LocalRule(), n_chains=20)
    sampler_state = sampler.init_state(model, parameters, seed=1)
    optimizer_state = optimizer.init(parameters)
    logger = nk.logging.RuntimeLog()
    for i in tqdm(range(n_iters)):
        # Sample and compute the energy and gradient (same as before)
        sampler_state = sampler.reset(model, parameters, state=sampler_state)
        samples, sampler_state = sampler.sample(
            model, parameters, state=sampler_state, chain_length=chain_length
        )
        E, E_grad = estimate_energy_and_gradient(model, parameters, hamiltonian_jax, samples)
        # Update the parameters with Adam
        updates, optimizer_state = optimizer.update(E_grad, optimizer_state, parameters)
        parameters = optax.apply_updates(parameters, updates)
        # Log the energy
        logger(step=i, item={'Energy': E})
    return parameters, logger
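A possible invocation, commented out like the other optional cells in this tutorial:
# Uncomment to run the Adam optimization
# parameters_adam, logger_adam = optimize_with_adam()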
10.2 Feed-Forward Neural Networks#
Try implementing a more complex ansatz using feed-forward networks:
class FeedForward(nn.Module):
hidden_size: int = 32
@nn.compact
def __call__(self, x):
# Use nn.Dense layers with relu activation
x = nn.Dense(self.hidden_size)(x)
x = nn.relu(x)
x = nn.Dense(self.hidden_size)(x)
return jnp.sum(x, axis=-1) # Pool over sites
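A hypothetical usage sketch: the feed-forward ansatz plugs into exactly the same machinery as MF and Jastrow, only the line that builds the model changes.
# Initialize the feed-forward ansatz and inspect its parameter shapes
model_ffn = FeedForward(hidden_size=32)
parameters_ffn = model_ffn.init(jax.random.key(0), np.ones((hi.size,)))
print(jax.tree.map(lambda x: x.shape, parameters_ffn))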
10.3 Comparison of Different Ansätze#
Compare the performance of different variational ansätze:
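One way to do this, sketched below under the assumption that the hypothetical helper run_vmc simply repeats the plain gradient-descent loop from Section 9 with the same sampler settings, is to run each ansatz for the same number of iterations and compare the final energies with the exact ground state:
def run_vmc(ansatz, n_iters=300, chain_length=50, learning_rate=0.005):
    # Hypothetical helper: a single VMC run with plain gradient descent
    pars = ansatz.init(jax.random.key(0), np.ones((hi.size,)))
    samp = nk.sampler.MetropolisSampler(hi, nk.sampler.rules.LocalRule(), n_chains=20)
    samp_state = samp.init_state(ansatz, pars, seed=1)
    log = nk.logging.RuntimeLog()
    for i in range(n_iters):
        samp_state = samp.reset(ansatz, pars, state=samp_state)
        sigma, samp_state = samp.sample(ansatz, pars, state=samp_state, chain_length=chain_length)
        E, E_grad = estimate_energy_and_gradient(ansatz, pars, hamiltonian_jax, sigma)
        pars = jax.tree.map(lambda x, y: x - learning_rate * y, pars, E_grad)
        log(step=i, item={"Energy": E})
    return log
# Uncomment to compare the ansätze used in this tutorial
# for name, ansatz in {"MF": MF(), "Jastrow": Jastrow()}.items():
#     log = run_vmc(ansatz)
#     print(f"{name}: final energy = {log.data['Energy']['Mean'][-1]:.4f} (exact: {e_gs:.4f})")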
Summary#
In this tutorial, you learned:
How to implement Monte Carlo sampling for VMC calculations
How to compute local energies using operator connections
How to estimate energies and gradients from samples
How to build complete VMC optimization loops
How to use advanced optimizers and neural network architectures
You now have the tools to get started quickly with VMC calculations without worrying about the implementation details of sampling and operators. This provides a foundation for implementing more advanced techniques such as:
Stochastic Reconfiguration (Natural Gradients)
Time evolution and dynamics
More sophisticated neural network architectures
Multi-GPU distributed calculations
The modular design allows you to focus on the physics and machine learning aspects while NetKet handles the computational infrastructure.