API
- class pathint.PathIntegralSampler(get_log_mu, x_size, t1, dt0, solver=Euler(), brownian_motion_tol=0.001)
Bases: object
Class defining loss and sampling functions for the path integral sampler.
This approach consists of a training objective and a sampling procedure for optimal control of the stochastic process
\[\mathrm{d}\mathbf{x}_t = \mathbf{u}_t \mathrm{d}t + \mathrm{d}\mathbf{w}_t ,\]
where \(\mathbf{w}_t\) is a Wiener process. A network trained to minimize the loss learns a control policy \(\mathbf{u}_t(t, \mathbf{x})\) under which the above process yields samples at time \(T\) from the prespecified distribution \(\mu(\cdot)\). (Distributions and quantities at time \(t=T\) are often referred to as “terminal”.) The procedure also yields importance sampling weights \(w\).
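To make the controlled SDE concrete, the sketch below discretizes it with a fixed-step Euler-Maruyama loop in NumPy. It is not the package's diffrax-based solver; the helper `simulate`, the start point \(\mathbf{x}_0 = 0\) at \(t = 0\), and the toy constant controls are assumptions for illustration. With \(\mathbf{u}_t = 0\) the terminal samples have variance close to \(T\), and a constant control \(c\) shifts their mean to roughly \(cT\).

```python
import numpy as np

def simulate(u, t0, t1, n_steps, n_paths, dim, rng):
    """Hypothetical helper: Euler-Maruyama for dx = u(t, x) dt + dw, starting at x = 0."""
    dt = (t1 - t0) / n_steps
    x = np.zeros((n_paths, dim))  # assumed start point x_0 = 0 for every path
    t = t0
    for _ in range(n_steps):
        dw = rng.normal(scale=np.sqrt(dt), size=x.shape)  # Wiener increments ~ N(0, dt I)
        x = x + u(t, x) * dt + dw
        t += dt
    return x

rng = np.random.default_rng(0)
# Uncontrolled process: terminal samples are approximately N(0, T I) with T = 1.
x_free = simulate(lambda t, x: np.zeros_like(x), 0.0, 1.0, 50, 4000, 2, rng)
# Constant control c = 2 steers the terminal mean toward c * T = 2.
x_push = simulate(lambda t, x: np.full_like(x, 2.0), 0.0, 1.0, 50, 4000, 2, rng)
print(x_free.mean(axis=0), x_free.var(axis=0))
print(x_push.mean(axis=0))
```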
Notes
As explained in the paper, the control policy network is trained by constructing an SDE augmented with the trajectory’s accumulated cost. This implementation uses a similar trick to sample and compute importance sampling weights simultaneously, using any SDE solver.
- get_diffusion_sampling(t, x, model)
Gets the diffusion coefficient for sampling.
- Parameters
  - t (Array) – time.
  - x (Array) – position.
  - model (Callable[[Array,Array],Array]) – control policy network taking t and x as arguments.
- Return type
Array
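One plausible reason this method takes `model` while `get_diffusion_train` does not: during sampling, the cost coordinate can also pick up the stochastic integral of \(\mathbf{u}_t \cdot \mathrm{d}\mathbf{w}_t\), which the importance weight needs. The NumPy sketch below builds such a diffusion matrix. It assumes, as for the other methods, that x is the augmented state (x[:-1] is \(x_t\), x[-1] is \(y_t\)); the name `diffusion_sampling` and the toy policy are hypothetical, not the package's code.

```python
import numpy as np

def diffusion_sampling(t, state, model):
    """Hypothetical sketch of a diffusion matrix for the augmented sampling SDE.

    state[:-1] is x_t and state[-1] is the accumulated cost y_t (assumed layout).
    """
    x = state[:-1]
    u = model(t, x)
    d = x.shape[0]
    # Identity rows: x_t receives the Wiener increment dw_t directly.
    # Last row u_t^T: y_t receives the scalar u_t . dw_t.
    return np.vstack([np.eye(d), u[None, :]])  # shape (d + 1, d)

model = lambda t, x: -x  # toy control policy
G = diffusion_sampling(0.0, np.array([1.0, 2.0, 0.0]), model)
increment = G @ np.array([0.1, 0.2])  # last entry is u . dw = -0.1 - 0.4
```

Multiplying by a Wiener increment shows the design: the first \(d\) entries reproduce \(\mathrm{d}\mathbf{w}_t\) and the last entry is \(\mathbf{u}_t \cdot \mathrm{d}\mathbf{w}_t\).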
- get_diffusion_train(t, x, _)
Gets the diffusion coefficient for the training SDE.
- Parameters
  - t (Array) – time.
  - x (Array) – state variable, with x[:-1] corresponding to \(x_t\) and x[-1] corresponding to the trajectory’s cost (\(y_t\) in the paper).
  - _ – unused argument required by diffrax.
- Return type
Array
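Because the model argument is unused, the training diffusion cannot depend on the control. A natural guess is that noise enters only the \(x_t\) coordinates while \(y_t\) accumulates the running cost through the drift alone; the \(\mathbf{u}_t \cdot \mathrm{d}\mathbf{w}_t\) term has zero mean, so leaving it out does not change the expected cost. A NumPy sketch under that assumption; `diffusion_train` is a hypothetical name.

```python
import numpy as np

def diffusion_train(t, state, _):
    """Hypothetical sketch: diffusion matrix for the augmented training SDE."""
    d = state.shape[0] - 1  # state[:-1] is x_t, state[-1] is y_t (assumed layout)
    # Noise drives x_t only; the cost row is zero, so y_t changes through its drift alone.
    return np.vstack([np.eye(d), np.zeros((1, d))])  # shape (d + 1, d)

G = diffusion_train(0.0, np.array([1.0, 2.0, 0.0]), None)
```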
- get_drift(t, x, model)
Gets the drift coefficient for the augmented SDE.
- Parameters
  - t (Array) – time.
  - x (Array) – state variable, with x[:-1] corresponding to \(x_t\) and x[-1] corresponding to the trajectory’s cost (\(y_t\) in the paper).
  - model (Callable[[Array,Array],Array]) – control policy network taking \(t\) and \(x_t\) as arguments.
- Return type
Array
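A NumPy sketch of what an augmented drift of this shape can look like, assuming the running cost is \(\frac{1}{2}\lVert\mathbf{u}_t\rVert^2\) as in the path integral sampler objective: the \(x_t\) block moves with the control and the \(y_t\) entry accumulates the running cost. The name `drift` and the toy policy are hypothetical.

```python
import numpy as np

def drift(t, state, model):
    """Hypothetical sketch: drift of the augmented state [x_t, y_t]."""
    x = state[:-1]  # state[-1] is the accumulated cost y_t (assumed layout)
    u = model(t, x)
    # dx_t gets u_t dt; dy_t gets |u_t|^2 / 2 dt.
    return np.concatenate([u, [0.5 * np.dot(u, u)]])

model = lambda t, x: -x  # toy control policy
f = drift(0.0, np.array([1.0, 2.0, 0.0]), model)  # values: -1, -2, 2.5
```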
- get_log_mu: Callable[[Array], Array]
\(\log \mu(x)\), the log of the (unnormalized) terminal density to be sampled from.
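For illustration, a hypothetical `get_log_mu` target: an equal-weight mixture of two unit-variance Gaussians centred at \(+2\) and \(-2\), written in NumPy. Normalizing constants are dropped since the density only needs to be known up to a constant; `np.logaddexp` keeps the log-sum stable.

```python
import numpy as np

def log_mu(x):
    """Hypothetical target: unnormalized log density of a two-component Gaussian mixture."""
    a = -0.5 * np.sum((x - 2.0) ** 2, axis=-1)  # component centred at +2
    b = -0.5 * np.sum((x + 2.0) ** 2, axis=-1)  # component centred at -2
    return np.logaddexp(a, b)  # log(exp(a) + exp(b)) without overflow
```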
- get_log_mu_0(x)
Gets the log probability under the terminal distribution of the uncontrolled process.
- Return type
Array
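If the process starts at \(\mathbf{x}_0 = 0\) at time 0 (an assumption; the constructor only exposes t1), the uncontrolled process \(\mathrm{d}\mathbf{x}_t = \mathrm{d}\mathbf{w}_t\) is a standard Wiener process, so its terminal distribution is \(\mathcal{N}(0, t_1 I)\). A NumPy sketch of that log density; the name `log_mu_0` and its explicit `t1` argument are hypothetical.

```python
import numpy as np

def log_mu_0(x, t1, t0=0.0):
    """Hypothetical sketch: log N(x; 0, (t1 - t0) I) for the uncontrolled terminal state."""
    var = t1 - t0  # unit diffusion, so the variance grows linearly in time
    d = x.shape[-1]
    return -0.5 * np.sum(x ** 2, axis=-1) / var - 0.5 * d * np.log(2 * np.pi * var)
```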
- get_loss(model, key)
Gets the loss for a single trajectory.
- Parameters
  - model (PyTree) – control policy network taking t and x as arguments.
  - key (PRNGKeyArray) – PRNG key for the trajectory.
- Returns
  cost – approximation to \(\int_{t_0}^{t_1} \mathrm{d}t \, \frac{1}{2} \lVert \mathbf{u}_t(t, \mathbf{x}_t ; \theta) \rVert^2 + \Psi(\mathbf{x}_T)\), where \(T = t_1\) and the second term is the terminal cost specified by the training procedure.
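A Monte Carlo sketch of this loss in NumPy, assuming (as in the path integral sampler method) that the terminal cost is \(\Psi(\mathbf{x}) = \log \mu_0(\mathbf{x}) - \log \mu(\mathbf{x})\), where \(\mu_0\) is the terminal density of the uncontrolled process started at 0 at \(t = 0\). The helper names, the fixed-step Euler-Maruyama loop, and the normalized toy target \(\mathcal{N}(1, I)\) are hypothetical; the package computes one trajectory per call and integrates with diffrax. For this target the constant control \(\mathbf{u}_t = 1\) is optimal and the estimate is close to 0, while the uncontrolled process gives roughly 0.5.

```python
import numpy as np

def log_mu_0(x, t1):
    # Terminal log density of the uncontrolled process from x_0 = 0 (assumed): N(0, t1 I).
    return -0.5 * np.sum(x ** 2, axis=-1) / t1 - 0.5 * x.shape[-1] * np.log(2 * np.pi * t1)

def log_mu(x):
    # Hypothetical normalized target N(1, I).
    return -0.5 * np.sum((x - 1.0) ** 2, axis=-1) - 0.5 * x.shape[-1] * np.log(2 * np.pi)

def estimate_loss(model, t1, n_steps, n_paths, dim, rng):
    """Average of running cost + terminal cost over simulated trajectories."""
    dt = t1 / n_steps
    x = np.zeros((n_paths, dim))
    y = np.zeros(n_paths)  # accumulated running cost, one entry per path
    t = 0.0
    for _ in range(n_steps):
        u = model(t, x)
        y += 0.5 * np.sum(u ** 2, axis=-1) * dt  # training SDE: y has no noise term
        x = x + u * dt + rng.normal(scale=np.sqrt(dt), size=x.shape)
        t += dt
    psi = log_mu_0(x, t1) - log_mu(x)  # assumed terminal cost Psi(x_T)
    return np.mean(y + psi)

rng = np.random.default_rng(0)
loss_opt = estimate_loss(lambda t, x: np.ones_like(x), 1.0, 20, 4000, 1, rng)
loss_zero = estimate_loss(lambda t, x: np.zeros_like(x), 1.0, 20, 4000, 1, rng)
print(loss_opt, loss_zero)
```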
- get_sample(model, key)
Generates a sample. To generate multiple samples, vmap over key.
- Parameters
  - model (PyTree) – control policy network taking t and x as arguments.
  - key (PRNGKeyArray) – PRNG key for the trajectory.
- Returns
  - x_T – sample.
  - log_w – log of the importance sampling weight.
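A NumPy sketch of how sampling and the importance weight can be computed together: the cost coordinate accumulates \(\frac{1}{2}\lVert\mathbf{u}_t\rVert^2 \mathrm{d}t + \mathbf{u}_t \cdot \mathrm{d}\mathbf{w}_t\), and then \(\log w = \log \mu(\mathbf{x}_T) - \log \mu_0(\mathbf{x}_T) - y_T\), which follows from Girsanov's theorem when the process starts at 0 at \(t = 0\) with unit diffusion (assumed here). The helper names and the unnormalized toy target \(\exp(-\lVert\mathbf{x}-1\rVert^2/2)\), whose normalizing constant is \(\sqrt{2\pi}\) in one dimension, are hypothetical; the sketch batches paths instead of using vmap over PRNG keys. The mean of \(w\) estimates the normalizing constant of \(\mu\), and self-normalized weights give expectations under \(\mu\).

```python
import numpy as np

def log_mu_0(x, t1):
    # Terminal log density of the uncontrolled process from x_0 = 0 (assumed): N(0, t1 I).
    return -0.5 * np.sum(x ** 2, axis=-1) / t1 - 0.5 * x.shape[-1] * np.log(2 * np.pi * t1)

def log_mu(x):
    # Hypothetical unnormalized target; its normalizing constant is (2 pi)^(d/2).
    return -0.5 * np.sum((x - 1.0) ** 2, axis=-1)

def sample_with_log_w(model, t1, n_steps, n_paths, dim, rng):
    """Euler-Maruyama on the augmented state; returns terminal samples and log weights."""
    dt = t1 / n_steps
    x = np.zeros((n_paths, dim))
    y = np.zeros(n_paths)
    t = 0.0
    for _ in range(n_steps):
        u = model(t, x)
        dw = rng.normal(scale=np.sqrt(dt), size=x.shape)  # shared by x and y
        # Cost row: drift |u|^2 / 2 plus diffusion u^T, i.e. it also collects u . dw.
        y += 0.5 * np.sum(u ** 2, axis=-1) * dt + np.sum(u * dw, axis=-1)
        x = x + u * dt + dw
        t += dt
    log_w = log_mu(x) - log_mu_0(x, t1) - y
    return x, log_w

rng = np.random.default_rng(0)
x_T, log_w = sample_with_log_w(lambda t, x: np.zeros_like(x), 1.0, 20, 4000, 1, rng)
w = np.exp(log_w)
Z_hat = np.mean(w)                                # estimates sqrt(2 pi), about 2.507
mean_hat = np.sum(w * x_T[:, 0]) / np.sum(w)      # self-normalized estimate of E_mu[x] = 1

# For this target the constant control u = 1 is optimal: every log weight is identical.
_, log_w_opt = sample_with_log_w(lambda t, x: np.ones_like(x), 1.0, 20, 4000, 1, rng)
print(Z_hat, mean_hat, log_w_opt.std())
```

With the optimal control every weight equals \(\sqrt{2\pi}\), so the estimator has zero variance; the uncontrolled run still gives unbiased but noisier weights.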