Initialises a PRNGKey using an optional starting seed.
If using sharding, the returned key will be replicated on every process.
- Parameters:
seed (Union[int, Any, None]) – An optional integer value to use as the seed.
root (int) – The master rank, used when running with multiple nodes (default 0).
- Return type:
Any
- Returns:
A sharded/broadcast jax.random.PRNGKey().
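The sketch below shows one way a function with this contract could behave in a multi-process JAX run, treating `root` as a JAX process index. It is not the library's implementation. The helper name `make_prng_key` is hypothetical. Two behaviours are assumptions: drawing a fresh seed from NumPy's default generator when `seed` is None, and sharing the key with `jax.experimental.multihost_utils.broadcast_one_to_all`, which stands in for whatever broadcast mechanism the library actually uses.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.experimental import multihost_utils


def make_prng_key(seed=None, root=0):
    """Hypothetical sketch: create a PRNG key that is identical on all processes.

    seed: optional int; if None, a random seed is drawn (assumed behaviour).
    root: index of the JAX process whose key is kept (default 0).
    """
    if seed is None:
        # Assumption: fall back to a fresh random seed. Each process draws its
        # own value here, which is why the broadcast below is needed.
        # Stay below 2**31 so the seed fits in int32 when x64 is disabled.
        seed = int(np.random.default_rng().integers(0, 2**31 - 1))
    key = jax.random.PRNGKey(seed)
    if jax.process_count() > 1:
        # Replace every process's key with the one computed on process `root`.
        key = jnp.asarray(
            multihost_utils.broadcast_one_to_all(
                key, is_source=jax.process_index() == root
            )
        )
    return key
```

On a single process the broadcast is skipped, so `make_prng_key(0)` gives the same key as `jax.random.PRNGKey(0)`. With several processes and `seed=None`, only the broadcast keeps the processes from ending up with different random streams.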