diff --git a/docs/src/references/changelog.rst b/docs/src/references/changelog.rst
index 8f501396..3e405a23 100644
--- a/docs/src/references/changelog.rst
+++ b/docs/src/references/changelog.rst
@@ -27,6 +27,7 @@ changelog `_ format. This project follows
 Added
 #####
 
+* Added a PyTorch implementation of the exponential integral function
 * Added ``dtype`` and ``device`` for ``Calculator`` classses
 
 Changed
diff --git a/docs/src/references/lib/math.rst b/docs/src/references/lib/math.rst
index b20c9c83..f2cde344 100644
--- a/docs/src/references/lib/math.rst
+++ b/docs/src/references/lib/math.rst
@@ -1,6 +1,6 @@
 Math
 ####
 
+.. autofunction:: torchpme.lib.exp1
 .. autofunction:: torchpme.lib.gamma
-.. autofunction:: torchpme.lib.torch_exp1
 .. autofunction:: torchpme.lib.gammaincc_over_powerlaw
diff --git a/src/torchpme/lib/__init__.py b/src/torchpme/lib/__init__.py
index 47a0aa54..f21cf355 100644
--- a/src/torchpme/lib/__init__.py
+++ b/src/torchpme/lib/__init__.py
@@ -4,7 +4,7 @@
     generate_kvectors_for_mesh,
     get_ns_mesh,
 )
-from .math import CustomExp1, gamma, gammaincc_over_powerlaw, torch_exp1
+from .math import exp1, gamma, gammaincc_over_powerlaw
 from .mesh_interpolator import MeshInterpolator
 from .splines import (
     CubicSpline,
@@ -28,7 +28,6 @@
     "generate_kvectors_for_mesh",
     "get_ns_mesh",
     "gamma",
-    "CustomExp1",
     "gammaincc_over_powerlaw",
-    "torch_exp1",
+    "exp1",
 ]
diff --git a/src/torchpme/lib/math.py b/src/torchpme/lib/math.py
index 871abbc2..aae243dc 100644
--- a/src/torchpme/lib/math.py
+++ b/src/torchpme/lib/math.py
@@ -1,5 +1,4 @@
 import torch
-from scipy.special import exp1
 from torch.special import gammaln
 
 
@@ -14,40 +13,89 @@ def gamma(x: torch.Tensor) -> torch.Tensor:
     return torch.exp(gammaln(x))
 
 
-class CustomExp1(torch.autograd.Function):
-    """Custom exponential integral function Exp1(x) to have an autograd-compatible version."""
-
+class _CustomExp1(torch.autograd.Function):
     @staticmethod
-    def forward(ctx, input):
-        ctx.save_for_backward(input)
-        input_numpy = input.cpu().numpy() if not input.is_cpu else input.numpy()
-        return torch.tensor(exp1(input_numpy), device=input.device, dtype=input.dtype)
+    def forward(ctx, x):
+        # this implementation is inspired by the one in scipy:
+        # https://github.com/scipy/scipy/blob/34d91ce06d4d05e564b79bf65288284247b1f3e3/scipy/special/xsf/expint.h#L22
+        ctx.save_for_backward(x)
+
+        # Constants
+        SCIPY_EULER = (
+            0.577215664901532860606512090082402431  # Euler-Mascheroni constant
+        )
+        inf = torch.inf
+
+        # Handle case when x == 0
+        result = torch.full_like(x, inf)
+        mask = x > 0
+
+        # Compute for x <= 1
+        x_small = x[mask & (x <= 1)]
+        if x_small.numel() > 0:
+            e1 = torch.ones_like(x_small)
+            r = torch.ones_like(x_small)
+            for k in range(1, 26):
+                r = -r * k * x_small / (k + 1.0) ** 2
+                e1 += r
+                if torch.all(torch.abs(r) <= torch.abs(e1) * 1e-15):
+                    break
+            result[mask & (x <= 1)] = -SCIPY_EULER - torch.log(x_small) + x_small * e1
+
+        # Compute for x > 1
+        x_large = x[mask & (x > 1)]
+        if x_large.numel() > 0:
+            m = 20 + (80.0 / x_large).to(torch.int32)
+            t0 = torch.zeros_like(x_large)
+            for k in range(m.max(), 0, -1):
+                t0 = k / (1.0 + k / (x_large + t0))
+            t = 1.0 / (x_large + t0)
+            result[mask & (x > 1)] = torch.exp(-x_large) * t
+
+        return result
 
     @staticmethod
     def backward(ctx, grad_output):
-        (input,) = ctx.saved_tensors
-        return -grad_output * torch.exp(-input) / input
+        (x,) = ctx.saved_tensors
+        return -grad_output * torch.exp(-x) / x
+
+
+def exp1(x):
+    r"""
+    Exponential integral E1.
 
-
-def torch_exp1(input):
-    """Wrapper for the custom exponential integral function."""
-    return CustomExp1.apply(input)
+    For a real number :math:`x > 0` the exponential integral can be defined as
+
+    .. math::
+
+        E_1(x) = \int_{x}^{\infty} \frac{e^{-t}}{t} dt
+
+    :param x: Input tensor (x > 0)
+    :return: Exponential integral E1(x)
+    """
+    return _CustomExp1.apply(x)
 
 
 def gammaincc_over_powerlaw(exponent: torch.Tensor, z: torch.Tensor) -> torch.Tensor:
-    """Function to compute the regularized incomplete gamma function complement for integer exponents."""
+    """
+    Compute the regularized incomplete gamma function complement for integer exponents.
+
+    :param exponent: Exponent of the power law
+    :param z: Value at which to evaluate the function
+    :return: Regularized incomplete gamma function complement
+    """
     if exponent == 1:
         return torch.exp(-z) / z
     if exponent == 2:
         return torch.sqrt(torch.pi / z) * torch.erfc(torch.sqrt(z))
     if exponent == 3:
-        return torch_exp1(z)
+        return exp1(z)
     if exponent == 4:
         return 2 * (
             torch.exp(-z) - torch.sqrt(torch.pi * z) * torch.erfc(torch.sqrt(z))
         )
     if exponent == 5:
-        return torch.exp(-z) - z * torch_exp1(z)
+        return torch.exp(-z) - z * exp1(z)
     if exponent == 6:
         return (
             (2 - 4 * z) * torch.exp(-z)
diff --git a/src/torchpme/potentials/inversepowerlaw.py b/src/torchpme/potentials/inversepowerlaw.py
index 9abeb823..35ff7ac7 100644
--- a/src/torchpme/potentials/inversepowerlaw.py
+++ b/src/torchpme/potentials/inversepowerlaw.py
@@ -137,12 +137,12 @@ def background_correction(self) -> torch.Tensor:
         # "charge neutrality" correction for 1/r^p potential diverges for exponent p = 3
         # and is not needed for p > 3 , so we set it to zero (see in
         # https://doi.org/10.48550/arXiv.2412.03281 SI section)
-        if self.exponent >= 3:
-            return torch.tensor(0.0, dtype=self.dtype, device=self.device)
         if self.smearing is None:
             raise ValueError(
                 "Cannot compute background correction without specifying `smearing`."
             )
+        if self.exponent >= 3:
+            return self.smearing * 0.0
         prefac = torch.pi**1.5 * (2 * self.smearing**2) ** ((3 - self.exponent) / 2)
         prefac /= (3 - self.exponent) * gamma(self.exponent / 2)
         return prefac
diff --git a/tests/lib/test_math.py b/tests/lib/test_math.py
index 4ec7037c..d02196cc 100644
--- a/tests/lib/test_math.py
+++ b/tests/lib/test_math.py
@@ -1,8 +1,8 @@
 import numpy as np
+import scipy.special
 import torch
-from scipy.special import exp1
 
-from torchpme.lib import torch_exp1
+from torchpme.lib import exp1
 
 
 def finite_difference_derivative(func, x, h=1e-5):
@@ -10,16 +10,22 @@ def finite_difference_derivative(func, x, h=1e-5):
     return (func(x + h) - func(x - h)) / (2 * h)
 
 
 def test_torch_exp1_consistency_with_scipy():
-    x = torch.rand(1000, dtype=torch.float64)
-    torch_result = torch_exp1(x)
-    scipy_result = exp1(x.numpy())
-    assert np.allclose(torch_result.numpy(), scipy_result, atol=1e-6)
+    seed = 42
+    torch.manual_seed(seed)
+    np.random.seed(seed)
+    random_tensor = torch.rand(100000) * 1000
+    random_array = random_tensor.numpy()
+    scipy_result = scipy.special.exp1(random_array)
+    torch_result = exp1(random_tensor)
+    assert np.allclose(scipy_result, torch_result.numpy(), atol=1e-15)
 
 def test_torch_exp1_derivative():
     x = torch.rand(1, dtype=torch.float64, requires_grad=True)
-    torch_result = torch_exp1(x)
+    torch_result = exp1(x)
     torch_result.backward()
     torch_exp1_prime = x.grad
-    finite_diff_result = finite_difference_derivative(exp1, x.detach().numpy())
+    finite_diff_result = finite_difference_derivative(
+        scipy.special.exp1, x.detach().numpy()
+    )
     assert np.allclose(torch_exp1_prime.numpy(), finite_diff_result, atol=1e-6)
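
Usage sketch (reviewer note, not part of the patch): the snippet below mirrors the updated tests and shows the new public API, assuming this branch of torchpme plus scipy are installed. It exercises both internal branches of ``_CustomExp1.forward`` (power series for x <= 1, continued fraction for x > 1) and the analytic backward, dE1/dx = -exp(-x)/x::

    import scipy.special
    import torch

    from torchpme.lib import exp1

    # values on both sides of the x = 1 branch point of the implementation
    x = torch.tensor([0.1, 0.9, 1.0, 5.0, 50.0], dtype=torch.float64)
    assert torch.allclose(exp1(x), torch.from_numpy(scipy.special.exp1(x.numpy())))

    # the custom backward returns -exp(-x) / x, so gradients flow through exp1
    y = torch.tensor([0.7], dtype=torch.float64, requires_grad=True)
    exp1(y).backward()
    assert torch.allclose(y.grad, -torch.exp(-y.detach()) / y.detach())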
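Mathematical cross-check (my reading, not stated in the patch): the closed forms dispatched in ``gammaincc_over_powerlaw`` are consistent with interpreting it as Gamma((3 - p) / 2, z) / z^((3 - p) / 2), which degenerates to E1(z) at p = 3, hence the switch to the ``exp1(z)`` branch there. A hedged verification for the exponents where a = (3 - p) / 2 is positive, so scipy's regularized function can be compared directly::

    import scipy.special
    import torch

    from torchpme.lib import exp1, gammaincc_over_powerlaw

    z = torch.tensor([0.8], dtype=torch.float64)

    # p = 3: Gamma(0, z) / z**0 is E1(z), matching the exp1 branch directly
    assert torch.allclose(gammaincc_over_powerlaw(torch.tensor(3), z), exp1(z))

    # p = 1, 2 give a = (3 - p) / 2 > 0, where scipy's regularized gammaincc
    # can be unregularized via Gamma(a, z) = gammaincc(a, z) * Gamma(a)
    for p in (1, 2):
        a = (3 - p) / 2
        expected = (
            scipy.special.gammaincc(a, z.item())
            * scipy.special.gamma(a)
            / z.item() ** a
        )
        value = gammaincc_over_powerlaw(torch.tensor(p), z)
        assert abs(value.item() - expected) < 1e-12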
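One behavioral note on the ``inversepowerlaw.py`` hunk: moving the ``smearing`` check ahead of the early return means a missing smearing now raises for every exponent, including p >= 3, and ``self.smearing * 0.0`` builds the zero correction with the dtype and device of the smearing tensor itself, so the explicit ``torch.tensor(0.0, dtype=..., device=...)`` construction is no longer needed. A minimal illustration of that idiom, with a hypothetical smearing value::

    import torch

    smearing = torch.tensor(0.4, dtype=torch.float64)

    # multiplying by 0.0 preserves the dtype and device of the source tensor
    zero = smearing * 0.0
    assert zero.dtype == smearing.dtype and zero.device == smearing.device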