HJ reachability basics
Contents
HJ reachability basics
import jax
import jax.numpy as jnp
import numpy as np
from IPython.display import HTML
import matplotlib.animation as anim
import matplotlib.pyplot as plt
!pip install tqdm
Collecting tqdm
Downloading tqdm-4.67.1-py3-none-any.whl.metadata (57 kB)
Downloading tqdm-4.67.1-py3-none-any.whl (78 kB)
Installing collected packages: tqdm
Successfully installed tqdm-4.67.1
!pip install hj_reachability
import hj_reachability as hj
Collecting hj_reachability
Downloading hj_reachability-0.7.0-py3-none-any.whl.metadata (3.1 kB)
Collecting flax>=0.6.6 (from hj_reachability)
Downloading flax-0.7.2-py3-none-any.whl.metadata (10.0 kB)
INFO: pip is looking at multiple versions of hj-reachability to determine which version is compatible with other requirements. This could take a while.
Collecting hj_reachability
Downloading hj_reachability-0.6.0-py3-none-any.whl.metadata (3.1 kB)
Requirement already satisfied: jax>=0.4.2 in /usr/share/miniconda/envs/__setup_conda/lib/python3.8/site-packages (from hj_reachability) (0.4.13)
Requirement already satisfied: numpy>=1.18.0 in /usr/share/miniconda/envs/__setup_conda/lib/python3.8/site-packages (from hj_reachability) (1.24.4)
Collecting msgpack (from flax>=0.6.6->hj_reachability)
Downloading msgpack-1.1.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (8.4 kB)
Collecting optax (from flax>=0.6.6->hj_reachability)
Downloading optax-0.1.8-py3-none-any.whl.metadata (14 kB)
Collecting orbax-checkpoint (from flax>=0.6.6->hj_reachability)
Downloading orbax_checkpoint-0.2.3-py3-none-any.whl.metadata (1.8 kB)
Collecting tensorstore (from flax>=0.6.6->hj_reachability)
Downloading tensorstore-0.1.45-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (2.9 kB)
Collecting rich>=11.1 (from flax>=0.6.6->hj_reachability)
Downloading rich-14.0.0-py3-none-any.whl.metadata (18 kB)
Requirement already satisfied: typing-extensions>=4.1.1 in /usr/share/miniconda/envs/__setup_conda/lib/python3.8/site-packages (from flax>=0.6.6->hj_reachability) (4.13.2)
Requirement already satisfied: PyYAML>=5.4.1 in /usr/share/miniconda/envs/__setup_conda/lib/python3.8/site-packages (from flax>=0.6.6->hj_reachability) (6.0.2)
Requirement already satisfied: ml-dtypes>=0.1.0 in /usr/share/miniconda/envs/__setup_conda/lib/python3.8/site-packages (from jax>=0.4.2->hj_reachability) (0.2.0)
Requirement already satisfied: opt-einsum in /usr/share/miniconda/envs/__setup_conda/lib/python3.8/site-packages (from jax>=0.4.2->hj_reachability) (3.4.0)
Requirement already satisfied: scipy>=1.7 in /usr/share/miniconda/envs/__setup_conda/lib/python3.8/site-packages (from jax>=0.4.2->hj_reachability) (1.10.1)
Requirement already satisfied: importlib-metadata>=4.6 in /usr/share/miniconda/envs/__setup_conda/lib/python3.8/site-packages (from jax>=0.4.2->hj_reachability) (8.5.0)
Requirement already satisfied: zipp>=3.20 in /usr/share/miniconda/envs/__setup_conda/lib/python3.8/site-packages (from importlib-metadata>=4.6->jax>=0.4.2->hj_reachability) (3.20.2)
Collecting markdown-it-py>=2.2.0 (from rich>=11.1->flax>=0.6.6->hj_reachability)
Downloading markdown_it_py-3.0.0-py3-none-any.whl.metadata (6.9 kB)
Requirement already satisfied: pygments<3.0.0,>=2.13.0 in /usr/share/miniconda/envs/__setup_conda/lib/python3.8/site-packages (from rich>=11.1->flax>=0.6.6->hj_reachability) (2.19.1)
Collecting absl-py>=0.7.1 (from optax->flax>=0.6.6->hj_reachability)
Downloading absl_py-2.2.2-py3-none-any.whl.metadata (2.6 kB)
Collecting chex>=0.1.7 (from optax->flax>=0.6.6->hj_reachability)
Downloading chex-0.1.7-py3-none-any.whl.metadata (17 kB)
Requirement already satisfied: jaxlib>=0.1.37 in /usr/share/miniconda/envs/__setup_conda/lib/python3.8/site-packages (from optax->flax>=0.6.6->hj_reachability) (0.4.13)
Collecting cached_property (from orbax-checkpoint->flax>=0.6.6->hj_reachability)
Downloading cached_property-2.0.1-py3-none-any.whl.metadata (10 kB)
Requirement already satisfied: importlib_resources in /usr/share/miniconda/envs/__setup_conda/lib/python3.8/site-packages (from orbax-checkpoint->flax>=0.6.6->hj_reachability) (6.4.5)
Collecting etils (from orbax-checkpoint->flax>=0.6.6->hj_reachability)
Downloading etils-1.3.0-py3-none-any.whl.metadata (5.5 kB)
Requirement already satisfied: nest_asyncio in /usr/share/miniconda/envs/__setup_conda/lib/python3.8/site-packages (from orbax-checkpoint->flax>=0.6.6->hj_reachability) (1.6.0)
Collecting dm-tree>=0.1.5 (from chex>=0.1.7->optax->flax>=0.6.6->hj_reachability)
Downloading dm_tree-0.1.8-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (1.9 kB)
Collecting toolz>=0.9.0 (from chex>=0.1.7->optax->flax>=0.6.6->hj_reachability)
Downloading toolz-1.0.0-py3-none-any.whl.metadata (5.1 kB)
Collecting mdurl~=0.1 (from markdown-it-py>=2.2.0->rich>=11.1->flax>=0.6.6->hj_reachability)
Downloading mdurl-0.1.2-py3-none-any.whl.metadata (1.6 kB)
Downloading hj_reachability-0.6.0-py3-none-any.whl (23 kB)
Downloading flax-0.7.2-py3-none-any.whl (226 kB)
Downloading rich-14.0.0-py3-none-any.whl (243 kB)
Downloading msgpack-1.1.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (381 kB)
Downloading optax-0.1.8-py3-none-any.whl (199 kB)
Downloading orbax_checkpoint-0.2.3-py3-none-any.whl (81 kB)
Downloading tensorstore-0.1.45-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (13.5 MB)
?25l ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 0.0/13.5 MB ? eta -:--:--
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 13.5/13.5 MB 125.9 MB/s eta 0:00:00
?25hDownloading absl_py-2.2.2-py3-none-any.whl (135 kB)
Downloading chex-0.1.7-py3-none-any.whl (89 kB)
Downloading markdown_it_py-3.0.0-py3-none-any.whl (87 kB)
Downloading cached_property-2.0.1-py3-none-any.whl (7.4 kB)
Downloading etils-1.3.0-py3-none-any.whl (126 kB)
Downloading dm_tree-0.1.8-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (152 kB)
Downloading mdurl-0.1.2-py3-none-any.whl (10.0 kB)
Downloading toolz-1.0.0-py3-none-any.whl (56 kB)
Installing collected packages: dm-tree, toolz, tensorstore, msgpack, mdurl, etils, cached_property, absl-py, markdown-it-py, rich, orbax-checkpoint, chex, optax, flax, hj_reachability
Attempting uninstall: markdown-it-py
Found existing installation: markdown-it-py 1.1.0
Uninstalling markdown-it-py-1.1.0:
Successfully uninstalled markdown-it-py-1.1.0
ERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
mdit-py-plugins 0.2.8 requires markdown-it-py~=1.0, but you have markdown-it-py 3.0.0 which is incompatible.
myst-parser 0.15.2 requires markdown-it-py<2.0.0,>=1.0.0, but you have markdown-it-py 3.0.0 which is incompatible.
Successfully installed absl-py-2.2.2 cached_property-2.0.1 chex-0.1.7 dm-tree-0.1.8 etils-1.3.0 flax-0.7.2 hj_reachability-0.6.0 markdown-it-py-3.0.0 mdurl-0.1.2 msgpack-1.1.0 optax-0.1.8 orbax-checkpoint-0.2.3 rich-14.0.0 tensorstore-0.1.45 toolz-1.0.0
Example system: Air3D
# Air3D example dynamics from hj_reachability.
dynamics = hj.systems.Air3d()

# 51 x 40 x 50 grid over the box [-6, 20] x [-10, 10] x [0, 2*pi].
# Dimension 2 is periodic (presumably a heading angle, given its 2*pi extent).
grid = hj.Grid.from_lattice_parameters_and_boundary_conditions(
    hj.sets.Box(
        np.array([-6.0, -10.0, 0.0]),
        np.array([20.0, 10.0, 2.0 * np.pi]),
    ),
    (51, 40, 50),
    periodic_dims=2,
)

# Initial value function: Euclidean distance of the first two state components
# from the origin, minus 5 -- negative inside the radius-5 disk, positive outside.
values = jnp.linalg.norm(grid.states[..., :2], axis=-1) - 5

# "very_high" accuracy preset; the backwards_reachable_tube Hamiltonian
# postprocessor presumably makes the solve produce a reachable tube (see the
# hj_reachability docs to confirm).
solver_settings = hj.SolverSettings.with_accuracy(
    "very_high",
    hamiltonian_postprocessor=hj.solver.backwards_reachable_tube,
)
No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
`hj.step`: propagate the HJ PDE from `(time, values)` to `target_time`
# Propagate the value function backwards in time, from t = 0.0 to t = -2.8.
time = 0.0
target_time = -2.8
target_values = hj.step(
    solver_settings, dynamics, grid, time, values, target_time
)
0%| | 0.0000/2.8 [00:00<?, ?sim_s/s]
2%|2 | 0.0616/2.8 [00:00<00:04, 1.81s/sim_s]
4%|4 | 0.1231/2.8 [00:00<00:04, 1.82s/sim_s]
7%|6 | 0.1847/2.8 [00:00<00:04, 1.82s/sim_s]
9%|8 | 0.2463/2.8 [00:00<00:04, 1.81s/sim_s]
11%|# | 0.3079/2.8 [00:00<00:04, 1.83s/sim_s]
13%|#3 | 0.3694/2.8 [00:00<00:04, 1.83s/sim_s]
15%|#5 | 0.4310/2.8 [00:00<00:04, 1.84s/sim_s]
18%|#7 | 0.4926/2.8 [00:00<00:04, 1.82s/sim_s]
20%|#9 | 0.5542/2.8 [00:01<00:04, 1.84s/sim_s]
22%|##1 | 0.6157/2.8 [00:01<00:04, 1.85s/sim_s]
24%|##4 | 0.6773/2.8 [00:01<00:03, 1.84s/sim_s]
26%|##6 | 0.7389/2.8 [00:01<00:03, 1.84s/sim_s]
29%|##8 | 0.8004/2.8 [00:01<00:03, 1.83s/sim_s]
31%|### | 0.8620/2.8 [00:01<00:03, 1.83s/sim_s]
33%|###2 | 0.9236/2.8 [00:01<00:03, 1.82s/sim_s]
35%|###5 | 0.9852/2.8 [00:01<00:03, 1.83s/sim_s]
37%|###7 | 1.0467/2.8 [00:01<00:03, 1.84s/sim_s]
40%|###9 | 1.1083/2.8 [00:02<00:03, 1.85s/sim_s]
42%|####1 | 1.1699/2.8 [00:02<00:03, 1.85s/sim_s]
44%|####3 | 1.2315/2.8 [00:02<00:02, 1.85s/sim_s]
46%|####6 | 1.2930/2.8 [00:02<00:02, 1.84s/sim_s]
48%|####8 | 1.3546/2.8 [00:02<00:02, 1.85s/sim_s]
51%|##### | 1.4162/2.8 [00:02<00:02, 1.85s/sim_s]
53%|#####2 | 1.4778/2.8 [00:02<00:02, 1.86s/sim_s]
55%|#####4 | 1.5316/2.8 [00:02<00:02, 1.86s/sim_s]
57%|#####6 | 1.5932/2.8 [00:02<00:02, 1.85s/sim_s]
59%|#####9 | 1.6548/2.8 [00:03<00:02, 1.84s/sim_s]
61%|######1 | 1.7163/2.8 [00:03<00:01, 1.84s/sim_s]
63%|######3 | 1.7779/2.8 [00:03<00:01, 1.82s/sim_s]
66%|######5 | 1.8395/2.8 [00:03<00:01, 1.83s/sim_s]
68%|######7 | 1.9011/2.8 [00:03<00:01, 1.83s/sim_s]
70%|####### | 1.9626/2.8 [00:03<00:01, 1.84s/sim_s]
72%|#######2 | 2.0242/2.8 [00:03<00:01, 1.85s/sim_s]
74%|#######4 | 2.0858/2.8 [00:03<00:01, 1.83s/sim_s]
77%|#######6 | 2.1474/2.8 [00:03<00:01, 1.83s/sim_s]
79%|#######8 | 2.2089/2.8 [00:04<00:01, 1.84s/sim_s]
81%|########1 | 2.2705/2.8 [00:04<00:00, 1.84s/sim_s]
83%|########3 | 2.3321/2.8 [00:04<00:00, 1.83s/sim_s]
85%|########5 | 2.3937/2.8 [00:04<00:00, 1.84s/sim_s]
88%|########7 | 2.4552/2.8 [00:04<00:00, 1.84s/sim_s]
90%|########9 | 2.5168/2.8 [00:04<00:00, 1.84s/sim_s]
92%|#########2| 2.5784/2.8 [00:04<00:00, 1.84s/sim_s]
94%|#########4| 2.6399/2.8 [00:04<00:00, 1.86s/sim_s]
96%|#########6| 2.7015/2.8 [00:04<00:00, 1.85s/sim_s]
99%|#########8| 2.7631/2.8 [00:05<00:00, 1.84s/sim_s]
100%|##########| 2.8000/2.8 [00:05<00:00, 1.84s/sim_s]
# Create the figure BEFORE selecting the colormap: plt.jet() looks up the
# current image via gcf(), so calling it first opened a stray empty figure.
plt.figure(figsize=(13, 8))
# Make "jet" the default colormap (also applies to later plots in this session).
plt.jet()

# Filled contours of the value function on slice index 30 of the third
# (periodic) dimension. target_values[:, :, 30] has shape (51, 40) = (len(x), len(y));
# contourf wants Z as (len(y), len(x)), hence the transpose.
plt.contourf(grid.coordinate_vectors[0], grid.coordinate_vectors[1], target_values[:, :, 30].T)
plt.colorbar()

# Overlay the zero level set in black. `levels` must be a sequence: a bare int
# is read by matplotlib as a *number* of auto-chosen levels, not the value 0.
plt.contour(grid.coordinate_vectors[0],
            grid.coordinate_vectors[1],
            target_values[:, :, 30].T,
            levels=[0],
            colors="black",
            linewidths=3)
plt.show()
<matplotlib.contour.QuadContourSet at 0x7fc25ca21b50>
<Figure size 640x480 with 0 Axes>

Value function evaluation
# Query state: the grid node with index (4, 5, 4).
state = grid.states[4, 5, 4]
# grid.interpolate evaluates the value function at an arbitrary state; because
# `state` is a grid node, the result should equal target_values[4, 5, 4].
V_value = grid.interpolate(target_values, state)
V_value
Array(2.681385, dtype=float32)
Value function gradient evaluation
# Gradient of the value function at every grid node (central differencing).
dV_values = grid.grad_values(target_values)
# Shape is the grid shape followed by the state dimension.
dV_values.shape
(51, 40, 50, 3)
# Interpolate the gradient field at `state`; since `state` is a grid node,
# this should reproduce dV_values[4, 5, 4].
grad_value = grid.interpolate(dV_values, state)
grad_value
Array([-0.57325226, -0.7299447 , -0.3905369 ], dtype=float32)
Compute the optimal policy
# Query state for the policy: the same grid node (4, 5, 4) as above.
state = grid.states[4, 5, 4]
# Call signature: optimal_control_and_disturbance(state, time, grad_value).
# Returns a (control, disturbance) pair, here evaluated at time 0.0.
a_opt, b_opt = dynamics.optimal_control_and_disturbance(state, 0.0, grad_value)
a_opt, b_opt
(Array([1.], dtype=float32), Array([1.], dtype=float32))