In [1]:
import astropy.units as u

In [3]:
import astropy.constants as c
import numpy as np

In [5]:
# Speed needed to cover the circumference of a circle of radius 1 au
# (2*pi au) in one year.
v = 2.*np.pi*u.au/u.year

In [6]:
v

<Quantity 6.28318531 AU / yr>

In [7]:
# Kinetic energy (1/2) * m * v**2, using the Earth's mass for m.
E_kin = 0.5*v**2 * c.M_earth

In [8]:
# G * M_earth * M_sun / (1 au): magnitude of the Sun-Earth gravitational
# potential energy at a separation of 1 au (no minus sign is applied here).
E_pot = c.G*c.M_earth*c.M_sun/u.au

In [10]:
# Convert the kinetic-to-potential ratio to a pure (dimensionless) number;
# it comes out close to 1/2.
(E_kin/E_pot).to(1)

<Quantity 0.50001889>

In [11]:
!pip install -U "jax[cpu]"



In [12]:
import jax

In [13]:
import jax.numpy as jnp

In [14]:
def pot(r):
    """Evaluate the inverse-distance potential ``-1 / |r|``.

    Parameters
    ----------
    r : indexable
        Position; components ``r[0]``, ``r[1]`` and ``r[2]`` are read as
        the Cartesian coordinates.

    Returns
    -------
    The value ``-1 / sqrt(r[0]**2 + r[1]**2 + r[2]**2)``.
    """
    x, y, z = r[0], r[1], r[2]
    radius = jnp.sqrt(x**2 + y**2 + z**2)
    return -1.0 / radius

In [15]:
# Sanity check: at unit distance from the origin the potential is -1.
pot([1,0,0])

Array(-1., dtype=float32, weak_type=True)

In [17]:
# Forward-mode derivative of `pot` with respect to the position, i.e. the
# gradient of the potential.
# NOTE(review): this is +grad(pot); a physical force is -grad(potential).
# Confirm that code using `force` accounts for the sign.
force = jax.jacfwd(pot)

In [20]:
# Evaluate the gradient at the point (1, 1, 0).
force([1.,1.,0.])

[Array(0.3535534, dtype=float32, weak_type=True),
 Array(0.3535534, dtype=float32, weak_type=True),
 Array(0., dtype=float32, weak_type=True)]

In [22]:
# Jacobian of `force`, i.e. the matrix of second derivatives of `pot`.
force_der = jax.jacfwd(force)

In [24]:
# Evaluate the second-derivative matrix at the point (1, 0, 0).
force_der([1.,0.,0.])

[[Array(-2., dtype=float32, weak_type=True),
  Array(0., dtype=float32, weak_type=True),
  Array(0., dtype=float32, weak_type=True)],
 [Array(0., dtype=float32, weak_type=True),
  Array(1., dtype=float32, weak_type=True),
  Array(0., dtype=float32, weak_type=True)],
 [Array(0., dtype=float32, weak_type=True),
  Array(0., dtype=float32, weak_type=True),
  Array(1., dtype=float32, weak_type=True)]]

In [28]:
def curl_force(r, field=None):
    """Curl of a 3-D vector field at position ``r``.

    Parameters
    ----------
    r : sequence or array with three components
        Point at which the curl is evaluated.
    field : callable, optional
        Vector field ``r -> (F_x, F_y, F_z)``. When omitted, the
        notebook-level ``force`` function is used.

    Returns
    -------
    list
        ``[curl_x, curl_y, curl_z]`` with the standard sign convention
        ``curl_x = dF_z/dy - dF_y/dz`` (and cyclic permutations).
    """
    if field is None:
        field = force
    # jac[i][j] = d field_i / d r_j (outputs index rows, inputs index columns).
    jac = jax.jacfwd(field)(r)
    return [
        jac[2][1] - jac[1][2],
        jac[0][2] - jac[2][0],
        jac[1][0] - jac[0][1],
    ]

In [35]:
# Evaluate the curl at an arbitrary point; every component comes out zero.
curl_force(np.array([1.213,1.99,0]))

[Array(0., dtype=float32), Array(0., dtype=float32), Array(0., dtype=float32)]