API

class pathint.PathIntegralSampler(get_log_mu, x_size, t1, dt0, solver=Euler(), brownian_motion_tol=0.001)

Bases: object

Class defining loss and sampling functions for the path integral sampler.

This approach consists of a training objective and a sampling procedure for the optimal control of the stochastic process

\[\mathrm{d}\mathbf{x}_t = \mathbf{u}_t \mathrm{d}t + \mathrm{d}\mathbf{w}_t ,\]

where \(\mathbf{w}_t\) is a Wiener process. A network is trained to find the control policy \(\mathbf{u}_t(t, \mathbf{x})\) that minimizes the loss function; under this policy, the process above yields samples at time \(T\) distributed according to the prespecified distribution \(\mu(\cdot)\). (Distributions and quantities at time \(t=T\) are often referred to as “terminal”.) The procedure also yields importance sampling weights \(w\), which correct for any mismatch between the learned policy and the optimal one.
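The controlled SDE above can be simulated with a fixed-step Euler-Maruyama scheme. This is a minimal sketch, assuming the process starts at the origin at \(t = 0\) and runs to \(T\) = `t1` with step `dt0`; `euler_maruyama` and `zero_policy` are hypothetical names, and the class itself delegates integration to a diffrax solver rather than this loop.

```python
import jax
import jax.numpy as jnp


def euler_maruyama(policy, key, x_size, t1, dt0):
    """Hypothetical helper: simulate dx = u(t, x) dt + dw from the origin."""
    n_steps = int(round(t1 / dt0))
    dt = t1 / n_steps
    # Brownian increments dw ~ N(0, dt * I), one row per step.
    dws = jnp.sqrt(dt) * jax.random.normal(key, (n_steps, x_size))

    def step(x, inputs):
        i, dw = inputs
        x = x + policy(i * dt, x) * dt + dw
        return x, None

    x_T, _ = jax.lax.scan(step, jnp.zeros(x_size), (jnp.arange(n_steps), dws))
    return x_T


# Uncontrolled process (u = 0): x_T ~ N(0, t1 * I).
zero_policy = lambda t, x: jnp.zeros_like(x)
keys = jax.random.split(jax.random.PRNGKey(0), 4000)
samples = jax.vmap(
    lambda k: euler_maruyama(zero_policy, k, x_size=2, t1=1.0, dt0=0.01)
)(keys)
```

With \(\mathbf{u}_t = 0\) the terminal samples are Gaussian with variance \(T\) per coordinate, so `samples.var(axis=0)` should be close to `t1`.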

Notes

As explained in the paper, the control policy network is trained by constructing an SDE augmented with the trajectory’s accumulated cost. This implementation uses a similar augmentation to sample and compute importance sampling weights simultaneously with any SDE solver.
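One way to write the two augmented SDEs, derived from Girsanov's theorem under the assumption that \(\Psi(\mathbf{x}) = \log \mu_0(\mathbf{x}) - \log \mu(\mathbf{x})\), with \(\mu_0\) the terminal density of the uncontrolled process and \(\mathbf{x}_0 = \mathbf{0}\), \(y_0 = 0\). This is a sketch; the exact augmentation used by the implementation may differ. During training the cost \(y_t\) accumulates only the running control cost; during sampling it also picks up \(\mathbf{u}_t^\top \mathrm{d}\mathbf{w}_t\), which has zero mean (so it leaves the expected cost unchanged) but is needed for the weight.

```latex
\begin{aligned}
\text{training:}\quad & \mathrm{d}\mathbf{x}_t = \mathbf{u}_t\,\mathrm{d}t + \mathrm{d}\mathbf{w}_t,
  && \mathrm{d}y_t = \tfrac{1}{2}\lVert \mathbf{u}_t \rVert^2\,\mathrm{d}t,
  && \text{cost} = y_T + \Psi(\mathbf{x}_T), \\
\text{sampling:}\quad & \mathrm{d}\mathbf{x}_t = \mathbf{u}_t\,\mathrm{d}t + \mathrm{d}\mathbf{w}_t,
  && \mathrm{d}y_t = \tfrac{1}{2}\lVert \mathbf{u}_t \rVert^2\,\mathrm{d}t + \mathbf{u}_t^\top \mathrm{d}\mathbf{w}_t,
  && \log w = -\bigl(y_T + \Psi(\mathbf{x}_T)\bigr) + \text{const}.
\end{aligned}
```

The constant absorbs the unknown normalization of \(\mu\), which is why the weights are typically self-normalized over a batch.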

brownian_motion_tol: float = 0.001

tolerance for dfx.VirtualBrownianTree.

dt0: float

initial timestep size for the solver.

get_diffusion_sampling(t, x, model)

Gets the diffusion coefficient for sampling.

Parameters
  • t (Array) – time.

  • x (Array) – position.

  • model (Callable[[Array, Array], Array]) – control policy network taking t and x as arguments.

Return type

Array
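A sketch of what a sampling diffusion coefficient could look like for the augmented state \((\mathbf{x}_t, y_t)\), assuming it is represented as a dense \((d+1) \times d\) matrix multiplying \(\mathrm{d}\mathbf{w}_t\). The identity block drives \(\mathbf{x}_t\), and the extra row \(\mathbf{u}_t^\top\) adds \(\mathbf{u}_t^\top \mathrm{d}\mathbf{w}_t\) to the cost, which would explain why this method needs `model`. `diffusion_sampling` and the toy policy are hypothetical; the library may use a different representation.

```python
import jax.numpy as jnp


def diffusion_sampling(t, z, model):
    """Hypothetical sampling diffusion for the augmented state z = (x_t, y_t).

    Returns a (d + 1, d) matrix G with dz = f dt + G dw: the identity block
    drives x_t and the last row u_t^T adds u_t . dw_t to the cost y_t.
    """
    x = z[:-1]
    u = model(t, x)
    return jnp.concatenate([jnp.eye(x.shape[0]), u[None, :]], axis=0)


# Toy linear policy u(t, x) = -x, for illustration only.
G = diffusion_sampling(0.0, jnp.array([1.0, 2.0, 0.0]), lambda t, x: -x)
```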

get_diffusion_train(t, x, _)

Gets the diffusion coefficient for the training SDE.

Parameters
  • t (Array) – time.

  • x (Array) – state variable, with x[:-1] corresponding to \(x_t\) and x[-1] corresponding to the trajectory’s cost (\(y_t\) in the paper).

  • _ – unused args argument required by the diffrax vector-field signature (t, y, args).

Return type

Array
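For training, a corresponding sketch replaces the model-dependent row with zeros, so the cost component \(y_t\) receives no noise. This assumes the same hypothetical dense \((d+1) \times d\) representation; `diffusion_train` is a hypothetical name, and its third argument mirrors the unused diffrax `args` slot.

```python
import jax.numpy as jnp


def diffusion_train(t, z, _):
    """Hypothetical training diffusion for the augmented state z = (x_t, y_t).

    Returns a (d + 1, d) matrix: identity noise on x_t, no noise on y_t.
    """
    d = z.shape[0] - 1
    return jnp.concatenate([jnp.eye(d), jnp.zeros((1, d))], axis=0)


G = diffusion_train(0.0, jnp.array([1.0, 2.0, 0.0]), None)
dw = jnp.array([0.3, -0.1])
dz = G @ dw  # noise increment for (x_t, y_t); the cost entry stays 0
```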

get_drift(t, x, model)

Gets the drift coefficient for the augmented SDE.

Parameters
  • t (Array) – time.

  • x (Array) – state variable, with x[:-1] corresponding to \(x_t\) and x[-1] corresponding to the trajectory’s cost (\(y_t\) in the paper).

  • model (Callable[[Array, Array], Array]) – control policy network taking \(t\) and \(x_t\) as arguments.

Return type

Array
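A sketch of the augmented drift, assuming the last component of the state is the accumulated cost: \(\mathbf{x}_t\) moves with the control \(\mathbf{u}_t\), and \(y_t\) grows at the running-cost rate \(\tfrac{1}{2}\lVert \mathbf{u}_t \rVert^2\), matching the integrand of the loss. `drift` and the linear toy policy are hypothetical.

```python
import jax.numpy as jnp


def drift(t, z, model):
    """Hypothetical augmented drift for z = (x_t, y_t): returns (u_t, 0.5 * |u_t|^2)."""
    x = z[:-1]
    u = model(t, x)
    running_cost = 0.5 * jnp.sum(u ** 2)
    return jnp.concatenate([u, running_cost[None]])


# Toy linear policy u(t, x) = -x, for illustration only.
f = drift(0.0, jnp.array([1.0, 2.0, 0.0]), lambda t, x: -x)
```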

get_log_mu: Callable[[Array], Array]

\(\log \mu(x)\), the log of the (unnormalized) terminal density to be sampled from.
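Since the density may be unnormalized, any log density known up to an additive constant can serve as `get_log_mu`. A hypothetical example target, an equal-weight mixture of two unit-variance Gaussians in 2-D with its normalizing constant dropped:

```python
import jax.numpy as jnp
from jax.scipy.special import logsumexp

# Hypothetical target: equal-weight mixture of two unit-variance Gaussians in 2-D.
MEANS = jnp.array([[-2.0, 0.0], [2.0, 0.0]])


def log_mu(x):
    """Unnormalized log density; additive constants are dropped."""
    sq_dists = jnp.sum((x - MEANS) ** 2, axis=-1)  # squared distance to each mean
    return logsumexp(-0.5 * sq_dists)
```

A sampler for this target would then be constructed with `get_log_mu=log_mu` and `x_size=2`.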

get_log_mu_0(x)

Gets the log probability of x under the terminal distribution of the uncontrolled process (\(\mathbf{u}_t = 0\)).

Return type

Array
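With \(\mathbf{u}_t = 0\), the SDE above reduces to \(\mathrm{d}\mathbf{x}_t = \mathrm{d}\mathbf{w}_t\); started at the origin at \(t = 0\) and run for time `t1`, its terminal distribution is \(\mathcal{N}(\mathbf{0}, t_1 I)\). A minimal sketch of that log density, assuming unit diffusion exactly as in the SDE written above (the implementation's scaling may differ); `log_mu_0` is a hypothetical free function, not the method itself.

```python
import jax.numpy as jnp
from jax.scipy.stats import norm


def log_mu_0(x, t1):
    """Log density of N(0, t1 * I), the terminal law of dx = dw started at 0."""
    return jnp.sum(norm.logpdf(x, loc=0.0, scale=jnp.sqrt(t1)))
```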

get_loss(model, key)

Gets the loss for a single trajectory.

Parameters
  • model (PyTree) – control policy network taking t and x as arguments.

  • key (PRNGKeyArray) – PRNG key for the trajectory.

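Returns

cost – approximation to \(\int_{t_0}^{t_1} \frac{1}{2} \lVert \mathbf{u}_t(t, \mathbf{x}_t; \theta) \rVert^2 \, \mathrm{d}t + \Psi(\mathbf{x}_T)\), where \(T = t_1\) and the second term is the terminal cost specified by the training procedure (in the paper, \(\Psi(\mathbf{x}) = \log \mu_0(\mathbf{x}) - \log \mu(\mathbf{x})\), with \(\mu_0\) the terminal density of the uncontrolled process).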
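A self-contained sketch of a single-trajectory loss, assuming fixed-step Euler-Maruyama in place of the configurable diffrax solver, unit diffusion, and \(\Psi = \log \mu_0 - \log \mu\) with \(\mu_0 = \mathcal{N}(\mathbf{0}, t_1 I)\). `get_loss_sketch` and the test functions are hypothetical. Averaging this over a batch of keys with `jax.vmap` and differentiating with respect to the policy parameters would give a training step.

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import norm


def get_loss_sketch(policy, key, log_mu, x_size, t1, dt0):
    """Hypothetical: single-trajectory cost y_T + Psi(x_T) via Euler-Maruyama."""
    n_steps = int(round(t1 / dt0))
    dt = t1 / n_steps
    dws = jnp.sqrt(dt) * jax.random.normal(key, (n_steps, x_size))

    def step(carry, inputs):
        x, y = carry
        i, dw = inputs
        u = policy(i * dt, x)
        # Training SDE: y accumulates only the running cost 0.5 * |u|^2 dt.
        y = y + 0.5 * jnp.sum(u ** 2) * dt
        return (x + u * dt + dw, y), None

    init = (jnp.zeros(x_size), jnp.zeros(()))
    (x_T, y_T), _ = jax.lax.scan(step, init, (jnp.arange(n_steps), dws))
    # Terminal cost Psi(x_T) = log mu_0(x_T) - log mu(x_T), with mu_0 = N(0, t1 I).
    log_mu_0 = jnp.sum(norm.logpdf(x_T, scale=jnp.sqrt(t1)))
    return y_T + log_mu_0 - log_mu(x_T)


# Sanity check: with u = 0 and mu = mu_0, both terms of the cost vanish.
zero_policy = lambda t, x: jnp.zeros_like(x)
log_mu_gauss = lambda x: jnp.sum(norm.logpdf(x, scale=1.0))
loss = get_loss_sketch(
    zero_policy, jax.random.PRNGKey(0), log_mu_gauss, x_size=2, t1=1.0, dt0=0.01
)
```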

get_sample(model, key)

Generates a sample and its importance sampling weight. To generate multiple samples, vmap over key.

Parameters
  • model (PyTree) – control policy network taking t and x as arguments.

  • key (PRNGKeyArray) – PRNG key for the trajectory.

Returns
  • x_T – sample.

  • log_w – log of the importance sampling weight.
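A sketch of sampling with importance weights, assuming fixed-step Euler-Maruyama, unit diffusion, and \(\Psi = \log \mu_0 - \log \mu\) with \(\mu_0 = \mathcal{N}(\mathbf{0}, t_1 I)\). Here the cost also accumulates \(\mathbf{u}_t^\top \mathrm{d}\mathbf{w}_t\), and \(\log w = -(y_T + \Psi(\mathbf{x}_T))\) up to a constant from the unknown normalization of \(\mu\). Vmapping over keys, as the docstring suggests, gives a batch whose weights can be self-normalized with a softmax. `get_sample_sketch` and the test functions are hypothetical.

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import norm


def get_sample_sketch(policy, key, log_mu, x_size, t1, dt0):
    """Hypothetical: returns (x_T, log_w) for one trajectory via Euler-Maruyama."""
    n_steps = int(round(t1 / dt0))
    dt = t1 / n_steps
    dws = jnp.sqrt(dt) * jax.random.normal(key, (n_steps, x_size))

    def step(carry, inputs):
        x, y = carry
        i, dw = inputs
        u = policy(i * dt, x)
        # Sampling SDE: the cost also picks up the stochastic term u . dw.
        y = y + 0.5 * jnp.sum(u ** 2) * dt + jnp.dot(u, dw)
        return (x + u * dt + dw, y), None

    init = (jnp.zeros(x_size), jnp.zeros(()))
    (x_T, y_T), _ = jax.lax.scan(step, init, (jnp.arange(n_steps), dws))
    # Psi(x_T) = log mu_0(x_T) - log mu(x_T), with mu_0 = N(0, t1 I).
    log_mu_0 = jnp.sum(norm.logpdf(x_T, scale=jnp.sqrt(t1)))
    log_w = -(y_T + log_mu_0 - log_mu(x_T))
    return x_T, log_w


# Batch of samples by vmapping over keys; u = 0 and mu = mu_0 as a sanity check.
zero_policy = lambda t, x: jnp.zeros_like(x)
log_mu_gauss = lambda x: jnp.sum(norm.logpdf(x, scale=1.0))
keys = jax.random.split(jax.random.PRNGKey(1), 1000)
x_T, log_w = jax.vmap(
    lambda k: get_sample_sketch(zero_policy, k, log_mu_gauss, x_size=2, t1=1.0, dt0=0.01)
)(keys)
weights = jax.nn.softmax(log_w)  # self-normalized importance weights
```

When the proposal already matches the target, as in this check, every weight is uniform; a poorly trained policy instead shows up as a few dominant weights.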

solver: AbstractSolver = Euler()

SDE solver.

t1: float

duration of the diffusion.

x_size: int

size of the \(x\) vector.

y0: Array

point at which the diffusion begins (the origin).