[WIP] Discrete MCMC with JAX and Numpyro #29
Draft
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Added a simple discrete MCMC method that can work with any energy function. A new proposal now flips multiple spins that can vary (from flipping no spins to all of them). Previously we were only flipping one spin and then deciding to accept or reject the proposal. We can try both and see what works the best (the old approach is commented).
This numpyro implementation seems very fast so it should help us speedup training of generative models with CD. We can create a large number of initial states (chains) and sample them in parallel.
Also added some basic tests for a known energy function that just sums up all the spins. The posterior samples for such an energy function is simple, all the spins should be -1, -1, .... so it forms a basic test case.