HJ reachability basics

import jax
import jax.numpy as jnp
import numpy as np

from IPython.display import HTML
import matplotlib.animation as anim
import matplotlib.pyplot as plt
!pip install tqdm
Collecting tqdm
  Downloading tqdm-4.67.1-py3-none-any.whl.metadata (57 kB)
Downloading tqdm-4.67.1-py3-none-any.whl (78 kB)
Installing collected packages: tqdm
Successfully installed tqdm-4.67.1
!pip install hj_reachability
import hj_reachability as hj
Collecting hj_reachability
  Downloading hj_reachability-0.7.0-py3-none-any.whl.metadata (3.1 kB)
Collecting flax>=0.6.6 (from hj_reachability)
  Downloading flax-0.7.2-py3-none-any.whl.metadata (10.0 kB)
INFO: pip is looking at multiple versions of hj-reachability to determine which version is compatible with other requirements. This could take a while.
Collecting hj_reachability
  Downloading hj_reachability-0.6.0-py3-none-any.whl.metadata (3.1 kB)
Requirement already satisfied: jax>=0.4.2 in /usr/share/miniconda/envs/__setup_conda/lib/python3.8/site-packages (from hj_reachability) (0.4.13)
Requirement already satisfied: numpy>=1.18.0 in /usr/share/miniconda/envs/__setup_conda/lib/python3.8/site-packages (from hj_reachability) (1.24.4)
Collecting msgpack (from flax>=0.6.6->hj_reachability)
  Downloading msgpack-1.1.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (8.4 kB)
Collecting optax (from flax>=0.6.6->hj_reachability)
  Downloading optax-0.1.8-py3-none-any.whl.metadata (14 kB)
Collecting orbax-checkpoint (from flax>=0.6.6->hj_reachability)
  Downloading orbax_checkpoint-0.2.3-py3-none-any.whl.metadata (1.8 kB)
Collecting tensorstore (from flax>=0.6.6->hj_reachability)
  Downloading tensorstore-0.1.45-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (2.9 kB)
Collecting rich>=11.1 (from flax>=0.6.6->hj_reachability)
  Downloading rich-14.0.0-py3-none-any.whl.metadata (18 kB)
Requirement already satisfied: typing-extensions>=4.1.1 in /usr/share/miniconda/envs/__setup_conda/lib/python3.8/site-packages (from flax>=0.6.6->hj_reachability) (4.13.2)
Requirement already satisfied: PyYAML>=5.4.1 in /usr/share/miniconda/envs/__setup_conda/lib/python3.8/site-packages (from flax>=0.6.6->hj_reachability) (6.0.2)
Requirement already satisfied: ml-dtypes>=0.1.0 in /usr/share/miniconda/envs/__setup_conda/lib/python3.8/site-packages (from jax>=0.4.2->hj_reachability) (0.2.0)
Requirement already satisfied: opt-einsum in /usr/share/miniconda/envs/__setup_conda/lib/python3.8/site-packages (from jax>=0.4.2->hj_reachability) (3.4.0)
Requirement already satisfied: scipy>=1.7 in /usr/share/miniconda/envs/__setup_conda/lib/python3.8/site-packages (from jax>=0.4.2->hj_reachability) (1.10.1)
Requirement already satisfied: importlib-metadata>=4.6 in /usr/share/miniconda/envs/__setup_conda/lib/python3.8/site-packages (from jax>=0.4.2->hj_reachability) (8.5.0)
Requirement already satisfied: zipp>=3.20 in /usr/share/miniconda/envs/__setup_conda/lib/python3.8/site-packages (from importlib-metadata>=4.6->jax>=0.4.2->hj_reachability) (3.20.2)
Collecting markdown-it-py>=2.2.0 (from rich>=11.1->flax>=0.6.6->hj_reachability)
  Downloading markdown_it_py-3.0.0-py3-none-any.whl.metadata (6.9 kB)
Requirement already satisfied: pygments<3.0.0,>=2.13.0 in /usr/share/miniconda/envs/__setup_conda/lib/python3.8/site-packages (from rich>=11.1->flax>=0.6.6->hj_reachability) (2.19.1)
Collecting absl-py>=0.7.1 (from optax->flax>=0.6.6->hj_reachability)
  Downloading absl_py-2.2.2-py3-none-any.whl.metadata (2.6 kB)
Collecting chex>=0.1.7 (from optax->flax>=0.6.6->hj_reachability)
  Downloading chex-0.1.7-py3-none-any.whl.metadata (17 kB)
Requirement already satisfied: jaxlib>=0.1.37 in /usr/share/miniconda/envs/__setup_conda/lib/python3.8/site-packages (from optax->flax>=0.6.6->hj_reachability) (0.4.13)
Collecting cached_property (from orbax-checkpoint->flax>=0.6.6->hj_reachability)
  Downloading cached_property-2.0.1-py3-none-any.whl.metadata (10 kB)
Requirement already satisfied: importlib_resources in /usr/share/miniconda/envs/__setup_conda/lib/python3.8/site-packages (from orbax-checkpoint->flax>=0.6.6->hj_reachability) (6.4.5)
Collecting etils (from orbax-checkpoint->flax>=0.6.6->hj_reachability)
  Downloading etils-1.3.0-py3-none-any.whl.metadata (5.5 kB)
Requirement already satisfied: nest_asyncio in /usr/share/miniconda/envs/__setup_conda/lib/python3.8/site-packages (from orbax-checkpoint->flax>=0.6.6->hj_reachability) (1.6.0)
Collecting dm-tree>=0.1.5 (from chex>=0.1.7->optax->flax>=0.6.6->hj_reachability)
  Downloading dm_tree-0.1.8-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (1.9 kB)
Collecting toolz>=0.9.0 (from chex>=0.1.7->optax->flax>=0.6.6->hj_reachability)
  Downloading toolz-1.0.0-py3-none-any.whl.metadata (5.1 kB)
Collecting mdurl~=0.1 (from markdown-it-py>=2.2.0->rich>=11.1->flax>=0.6.6->hj_reachability)
  Downloading mdurl-0.1.2-py3-none-any.whl.metadata (1.6 kB)
Downloading hj_reachability-0.6.0-py3-none-any.whl (23 kB)
Downloading flax-0.7.2-py3-none-any.whl (226 kB)
Downloading rich-14.0.0-py3-none-any.whl (243 kB)
Downloading msgpack-1.1.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (381 kB)
Downloading optax-0.1.8-py3-none-any.whl (199 kB)
Downloading orbax_checkpoint-0.2.3-py3-none-any.whl (81 kB)
Downloading tensorstore-0.1.45-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (13.5 MB)
?25l   ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 0.0/13.5 MB ? eta -:--:--
   ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 13.5/13.5 MB 125.9 MB/s eta 0:00:00
?25hDownloading absl_py-2.2.2-py3-none-any.whl (135 kB)
Downloading chex-0.1.7-py3-none-any.whl (89 kB)
Downloading markdown_it_py-3.0.0-py3-none-any.whl (87 kB)
Downloading cached_property-2.0.1-py3-none-any.whl (7.4 kB)
Downloading etils-1.3.0-py3-none-any.whl (126 kB)
Downloading dm_tree-0.1.8-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (152 kB)
Downloading mdurl-0.1.2-py3-none-any.whl (10.0 kB)
Downloading toolz-1.0.0-py3-none-any.whl (56 kB)
Installing collected packages: dm-tree, toolz, tensorstore, msgpack, mdurl, etils, cached_property, absl-py, markdown-it-py, rich, orbax-checkpoint, chex, optax, flax, hj_reachability
  Attempting uninstall: markdown-it-py
    Found existing installation: markdown-it-py 1.1.0
    Uninstalling markdown-it-py-1.1.0:
      Successfully uninstalled markdown-it-py-1.1.0
ERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
mdit-py-plugins 0.2.8 requires markdown-it-py~=1.0, but you have markdown-it-py 3.0.0 which is incompatible.
myst-parser 0.15.2 requires markdown-it-py<2.0.0,>=1.0.0, but you have markdown-it-py 3.0.0 which is incompatible.
Successfully installed absl-py-2.2.2 cached_property-2.0.1 chex-0.1.7 dm-tree-0.1.8 etils-1.3.0 flax-0.7.2 hj_reachability-0.6.0 markdown-it-py-3.0.0 mdurl-0.1.2 msgpack-1.1.0 optax-0.1.8 orbax-checkpoint-0.2.3 rich-14.0.0 tensorstore-0.1.45 toolz-1.0.0

Example system: Air3D

# Air3D dynamics on a 51 x 40 x 50 lattice over the box
# [-6, 20] x [-10, 10] x [0, 2*pi]; grid dimension 2 is periodic.
dynamics = hj.systems.Air3d()
grid = hj.Grid.from_lattice_parameters_and_boundary_conditions(
    hj.sets.Box(np.array([-6., -10., 0.]), np.array([20., 10., 2 * np.pi])),
    (51, 40, 50),
    periodic_dims=2,
)

# Initial value function: signed distance, measured in the first two state
# coordinates only, to a disk of radius 5 centred at the origin
# (negative inside the disk, zero on its boundary).
values = jnp.linalg.norm(grid.states[..., :2], axis=-1) - 5

# "very_high" accuracy preset. The Hamiltonian postprocessor is
# hj.solver.backwards_reachable_tube (by its name, a reachable-tube rather
# than reachable-set computation — see the hj_reachability solver docs).
solver_settings = hj.SolverSettings.with_accuracy(
    "very_high",
    hamiltonian_postprocessor=hj.solver.backwards_reachable_tube,
)
No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)

`hj.step`: propagate the HJ PDE from `(time, values)` to `target_time`.

# Solve the HJ PDE backwards in time: start from `values` at t = 0 and
# propagate down to t = -2.8.
time, target_time = 0., -2.8
target_values = hj.step(solver_settings, dynamics, grid, time, values, target_time)
  0%|          |  0.0000/2.8 [00:00<?, ?sim_s/s]
  2%|2         |  0.0616/2.8 [00:00<00:04,  1.81s/sim_s]
  4%|4         |  0.1231/2.8 [00:00<00:04,  1.82s/sim_s]
  7%|6         |  0.1847/2.8 [00:00<00:04,  1.82s/sim_s]
  9%|8         |  0.2463/2.8 [00:00<00:04,  1.81s/sim_s]
 11%|#         |  0.3079/2.8 [00:00<00:04,  1.83s/sim_s]
 13%|#3        |  0.3694/2.8 [00:00<00:04,  1.83s/sim_s]
 15%|#5        |  0.4310/2.8 [00:00<00:04,  1.84s/sim_s]
 18%|#7        |  0.4926/2.8 [00:00<00:04,  1.82s/sim_s]
 20%|#9        |  0.5542/2.8 [00:01<00:04,  1.84s/sim_s]
 22%|##1       |  0.6157/2.8 [00:01<00:04,  1.85s/sim_s]
 24%|##4       |  0.6773/2.8 [00:01<00:03,  1.84s/sim_s]
 26%|##6       |  0.7389/2.8 [00:01<00:03,  1.84s/sim_s]
 29%|##8       |  0.8004/2.8 [00:01<00:03,  1.83s/sim_s]
 31%|###       |  0.8620/2.8 [00:01<00:03,  1.83s/sim_s]
 33%|###2      |  0.9236/2.8 [00:01<00:03,  1.82s/sim_s]
 35%|###5      |  0.9852/2.8 [00:01<00:03,  1.83s/sim_s]
 37%|###7      |  1.0467/2.8 [00:01<00:03,  1.84s/sim_s]
 40%|###9      |  1.1083/2.8 [00:02<00:03,  1.85s/sim_s]
 42%|####1     |  1.1699/2.8 [00:02<00:03,  1.85s/sim_s]
 44%|####3     |  1.2315/2.8 [00:02<00:02,  1.85s/sim_s]
 46%|####6     |  1.2930/2.8 [00:02<00:02,  1.84s/sim_s]
 48%|####8     |  1.3546/2.8 [00:02<00:02,  1.85s/sim_s]
 51%|#####     |  1.4162/2.8 [00:02<00:02,  1.85s/sim_s]
 53%|#####2    |  1.4778/2.8 [00:02<00:02,  1.86s/sim_s]
 55%|#####4    |  1.5316/2.8 [00:02<00:02,  1.86s/sim_s]
 57%|#####6    |  1.5932/2.8 [00:02<00:02,  1.85s/sim_s]
 59%|#####9    |  1.6548/2.8 [00:03<00:02,  1.84s/sim_s]
 61%|######1   |  1.7163/2.8 [00:03<00:01,  1.84s/sim_s]
 63%|######3   |  1.7779/2.8 [00:03<00:01,  1.82s/sim_s]
 66%|######5   |  1.8395/2.8 [00:03<00:01,  1.83s/sim_s]
 68%|######7   |  1.9011/2.8 [00:03<00:01,  1.83s/sim_s]
 70%|#######   |  1.9626/2.8 [00:03<00:01,  1.84s/sim_s]
 72%|#######2  |  2.0242/2.8 [00:03<00:01,  1.85s/sim_s]
 74%|#######4  |  2.0858/2.8 [00:03<00:01,  1.83s/sim_s]
 77%|#######6  |  2.1474/2.8 [00:03<00:01,  1.83s/sim_s]
 79%|#######8  |  2.2089/2.8 [00:04<00:01,  1.84s/sim_s]
 81%|########1 |  2.2705/2.8 [00:04<00:00,  1.84s/sim_s]
 83%|########3 |  2.3321/2.8 [00:04<00:00,  1.83s/sim_s]
 85%|########5 |  2.3937/2.8 [00:04<00:00,  1.84s/sim_s]
 88%|########7 |  2.4552/2.8 [00:04<00:00,  1.84s/sim_s]
 90%|########9 |  2.5168/2.8 [00:04<00:00,  1.84s/sim_s]
 92%|#########2|  2.5784/2.8 [00:04<00:00,  1.84s/sim_s]
 94%|#########4|  2.6399/2.8 [00:04<00:00,  1.86s/sim_s]
 96%|#########6|  2.7015/2.8 [00:04<00:00,  1.85s/sim_s]
 99%|#########8|  2.7631/2.8 [00:05<00:00,  1.84s/sim_s]
100%|##########|  2.8000/2.8 [00:05<00:00,  1.84s/sim_s]

# Create the figure *before* selecting the colormap: plt.jet() also applies
# the colormap to the current image, and with no figure open yet pyplot would
# create an extra, empty figure just for that.
plt.figure(figsize=(13, 8))
plt.jet()  # make "jet" the default colormap for subsequent plots

# Filled contours of the value function on the slice at index 30 along the
# third grid dimension. `.T` is needed because contourf expects Z with shape
# (len(Y), len(X)), while target_values is indexed [dim0, dim1, dim2].
plt.contourf(grid.coordinate_vectors[0], grid.coordinate_vectors[1], target_values[:, :, 30].T)
plt.colorbar()

# Overlay the zero level set of the same slice. `levels=[0]` draws exactly the
# V = 0 contour; an integer `levels=0` would instead be read as a *number* of
# automatically chosen levels and only land on zero by coincidence.
plt.contour(grid.coordinate_vectors[0],
            grid.coordinate_vectors[1],
            target_values[:, :, 30].T,
            levels=[0],
            colors="black",
            linewidths=3)
<matplotlib.contour.QuadContourSet at 0x7fc25ca21b50>
<Figure size 640x480 with 0 Axes>
../_images/hj_reachability_basics_8_2.png

Value evaluation

# Take the grid node with index (4, 5, 4) as the state at which to evaluate
# the value function.
state = grid.states[4, 5, 4]

# grid.interpolate evaluates a grid-defined field at an arbitrary state.
# Because `state` is itself a grid node, this should reproduce
# target_values[4, 5, 4].
V_value = grid.interpolate(target_values, state)
V_value
Array(2.681385, dtype=float32)

Evaluating the gradient of the value function

# Gradient of the value function at every grid node.
# NOTE(review): described here as central differencing — confirm against
# hj.Grid.grad_values for the scheme actually used.
dV_values = grid.grad_values(target_values)
# One partial derivative per state dimension in the trailing axis:
# the shape is the grid shape followed by state_dim.
dV_values.shape
(51, 40, 50, 3)
# Interpolate the gradient field at `state`; since `state` is a grid node,
# the result should match dV_values[4, 5, 4].
grad_value = grid.interpolate(dV_values, state)
grad_value
Array([-0.57325226, -0.7299447 , -0.3905369 ], dtype=float32)

Compute the optimal policy

state = grid.states[4, 5, 4]  # define a state to evaluate the policy at

# Interpolate the value-function gradient at *this* state instead of reusing a
# gradient evaluated for some other state, so the policy below always
# corresponds to `state` even if the index above is changed.
grad_value = grid.interpolate(dV_values, state)

# Signature: optimal_control_and_disturbance(state, time, grad_value).
# NOTE(review): time 0. is passed although the values were propagated to
# target_time; presumably irrelevant for the Air3d dynamics — confirm.
a_opt, b_opt = dynamics.optimal_control_and_disturbance(state, 0., grad_value)
a_opt, b_opt
(Array([1.], dtype=float32), Array([1.], dtype=float32))