Add PyTorch implementation of the exponential integral function #145

Open · wants to merge 2 commits into base: main
1 change: 1 addition & 0 deletions docs/src/references/changelog.rst
@@ -27,6 +27,7 @@ changelog <https://keepachangelog.com/en/1.1.0/>`_ format. This project follows
Added
#####

* Added a PyTorch implementation of the exponential integral function
* Added ``dtype`` and ``device`` for ``Calculator`` classes

Fixed
4 changes: 2 additions & 2 deletions src/torchpme/lib/__init__.py
@@ -4,7 +4,7 @@
generate_kvectors_for_mesh,
get_ns_mesh,
)
from .math import CustomExp1, gamma, gammaincc_over_powerlaw, torch_exp1
from .math import CustomExp1, exp1, gamma, gammaincc_over_powerlaw
from .mesh_interpolator import MeshInterpolator

__all__ = [
@@ -20,5 +20,5 @@
"gamma",
"CustomExp1",
"gammaincc_over_powerlaw",
"torch_exp1",
"exp1",
]
59 changes: 51 additions & 8 deletions src/torchpme/lib/math.py
@@ -1,5 +1,4 @@
import torch
from scipy.special import exp1
from torch.special import gammaln


@@ -15,39 +14,83 @@ def gamma(x: torch.Tensor) -> torch.Tensor:


class CustomExp1(torch.autograd.Function):
"""Custom exponential integral function Exp1(x) to have an autograd-compatible version."""
"""
Compute the exponential integral E1(x) for x > 0.
:param input: Input tensor (x > 0)
:return: Exponential integral E1(x)
"""

@staticmethod
def forward(ctx, input):
ctx.save_for_backward(input)
input_numpy = input.cpu().numpy() if not input.is_cpu else input.numpy()
return torch.tensor(exp1(input_numpy), device=input.device, dtype=input.dtype)

Contributor suggested change:
# this implementation is inspired by the one in scipy: https://github.com/scipy/scipy/blob/55247d45469cf6203e8f562ab2b7e081ce41d0d1/scipy/special/xsf/cephes/exp10.h#L89

# Constants
SCIPY_EULER = (
0.577215664901532860606512090082402431 # Euler-Mascheroni constant
)
inf = torch.inf

# Handle case when x == 0
result = torch.full_like(input, inf)
mask = input > 0

# Compute for x <= 1
x_small = input[mask & (input <= 1)]
if x_small.numel() > 0:
e1 = torch.ones_like(x_small)
r = torch.ones_like(x_small)
for k in range(1, 26):
r = -r * k * x_small / (k + 1.0) ** 2
e1 += r
if torch.all(torch.abs(r) <= torch.abs(e1) * 1e-15):
break
Comment on lines +42 to +46 (Contributor): Nice! Did you check if this is faster? I am just worried by the two loops here, but maybe it is fine.
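One way to answer the performance question (a sketch, not part of the PR; the import torchpme.lib.exp1 follows this diff, and any timings are machine dependent):

import timeit

import torch
from scipy.special import exp1 as scipy_exp1

from torchpme.lib import exp1  # name introduced by this PR

# same kind of input as the test below: 1e5 positive values spread over [0, 1000)
x = torch.rand(100_000, dtype=torch.float64) * 1000

t_torch = timeit.timeit(lambda: exp1(x), number=20)
t_scipy = timeit.timeit(lambda: scipy_exp1(x.numpy()), number=20)
print(f"torch exp1: {t_torch:.3f} s, scipy exp1: {t_scipy:.3f} s")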

result[mask & (input <= 1)] = (
-SCIPY_EULER - torch.log(x_small) + x_small * e1
)

# Compute for x > 1
x_large = input[mask & (input > 1)]
if x_large.numel() > 0:
m = 20 + (80.0 / x_large).to(torch.int32)
t0 = torch.zeros_like(x_large)
for k in range(m.max(), 0, -1):
t0 = k / (1.0 + k / (x_large + t0))
t = 1.0 / (x_large + t0)
result[mask & (input > 1)] = torch.exp(-x_large) * t

return result
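An annotation, not part of the diff: the two branches above appear to implement the standard small- and large-argument expansions of the exponential integral,

E_1(x) = -\gamma - \ln x + \sum_{k=1}^{\infty} \frac{(-1)^{k+1} x^k}{k \cdot k!} \qquad (0 < x \le 1),

E_1(x) = e^{-x} \cfrac{1}{x + \cfrac{1}{1 + \cfrac{1}{x + \cfrac{2}{1 + \cfrac{2}{x + \cdots}}}}} \qquad (x > 1),

with \gamma the Euler-Mascheroni constant; the series is truncated after at most 25 terms and the continued fraction after m = 20 + \lfloor 80 / x \rfloor levels, consistent with the 1e-15 stopping criterion above.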

@staticmethod
def backward(ctx, grad_output):
(input,) = ctx.saved_tensors
return -grad_output * torch.exp(-input) / input


def torch_exp1(input):
def exp1(input):
"""Wrapper for the custom exponential integral function."""
return CustomExp1.apply(input)
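A small autograd sanity check (not part of the PR; it assumes the torchpme.lib.exp1 export added in this diff). The backward above encodes dE_1(x)/dx = -e^{-x}/x, so the analytic gradient can be compared directly:

import torch

from torchpme.lib import exp1

x = torch.tensor([0.1, 1.0, 5.0], dtype=torch.float64, requires_grad=True)
exp1(x).sum().backward()
expected = -torch.exp(-x) / x  # analytic derivative of E1
print(torch.allclose(x.grad, expected.detach()))  # expected: True
# torch.autograd.gradcheck(exp1, (x,)) should also pass in float64 for a numerical check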


def gammaincc_over_powerlaw(exponent: torch.Tensor, z: torch.Tensor) -> torch.Tensor:
"""Function to compute the regularized incomplete gamma function complement for integer exponents."""
"""
Function to compute the regularized incomplete gamma function complement for integer
exponents.
param exponent: Exponent of the power law
param z: Value at which to evaluate the function
return: Regularized incomplete gamma function complement
"""
if exponent == 1:
return torch.exp(-z) / z
if exponent == 2:
Contributor: Can you add the parameter docstring of gammaincc_over_powerlaw for me 😍

return torch.sqrt(torch.pi / z) * torch.erfc(torch.sqrt(z))
if exponent == 3:
return torch_exp1(z)
return exp1(z)
if exponent == 4:
return 2 * (
torch.exp(-z) - torch.sqrt(torch.pi * z) * torch.erfc(torch.sqrt(z))
)
if exponent == 5:
return torch.exp(-z) - z * torch_exp1(z)
return torch.exp(-z) - z * exp1(z)
if exponent == 6:
return (
(2 - 4 * z) * torch.exp(-z)
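An annotation, not part of the diff: the integer-exponent branches of gammaincc_over_powerlaw appear to be the closed forms of

\Gamma\left(\frac{3 - p}{2}, z\right) \Big/ z^{(3 - p)/2}

for a 1/r^p potential, where \Gamma(\cdot, z) is the upper incomplete gamma function; the p = 3 case reduces to \Gamma(0, z) = E_1(z), which is why the new exp1 is needed here.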
4 changes: 2 additions & 2 deletions src/torchpme/potentials/inversepowerlaw.py
@@ -137,12 +137,12 @@ def background_correction(self) -> torch.Tensor:
# "charge neutrality" correction for 1/r^p potential diverges for exponent p = 3
# and is not needed for p > 3 , so we set it to zero (see in
# https://doi.org/10.48550/arXiv.2412.03281 SI section)
if self.exponent >= 3:
return torch.tensor(0.0, dtype=self.dtype, device=self.device)
if self.smearing is None:
raise ValueError(
"Cannot compute background correction without specifying `smearing`."
)
if self.exponent >= 3:
return self.smearing * 0.0
prefac = torch.pi**1.5 * (2 * self.smearing**2) ** ((3 - self.exponent) / 2)
prefac /= (3 - self.exponent) * gamma(self.exponent / 2)
return prefac
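A hedged illustration of the reordering above (the class name, constructor arguments, and import path are assumed from the file path and may differ): moving the exponent >= 3 early return before the smearing check means the zero correction is returned as a tensor with the potential's dtype and device, and no longer requires smearing to be set.

import torch

from torchpme import InversePowerLawPotential  # assumed import path

pot = InversePowerLawPotential(exponent=3.0, smearing=None)  # hypothetical arguments
print(pot.background_correction())  # expected with this change: tensor(0.)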
14 changes: 9 additions & 5 deletions tests/lib/test_math.py
@@ -2,18 +2,22 @@
import torch
from scipy.special import exp1

from torchpme.lib import torch_exp1
from torchpme.lib import exp1 as torch_exp1


def finite_difference_derivative(func, x, h=1e-5):
return (func(x + h) - func(x - h)) / (2 * h)


def test_torch_exp1_consistency_with_scipy():
x = torch.rand(1000, dtype=torch.float64)
torch_result = torch_exp1(x)
scipy_result = exp1(x.numpy())
assert np.allclose(torch_result.numpy(), scipy_result, atol=1e-6)
seed = 42
torch.manual_seed(seed)
np.random.seed(seed)
random_tensor = torch.rand(100000) * 1000
random_array = random_tensor.numpy()
scipy_result = exp1(random_array)
Contributor: I think we can use fixed values. We should check edge cases. Maybe check scipy's exp1 for the values they have tests for.

Contributor Author: These are the only tests I was able to find in the SciPy repository:

class TestExp1:

    def test_branch_cut(self):
        assert np.isnan(sc.exp1(-1))
        assert sc.exp1(complex(-1, 0)).imag == (
            -sc.exp1(complex(-1, -0.0)).imag
        )

        assert_allclose(
            sc.exp1(complex(-1, 0)),
            sc.exp1(-1 + 1e-20j),
            atol=0,
            rtol=1e-15
        )
        assert_allclose(
            sc.exp1(complex(-1, -0.0)),
            sc.exp1(-1 - 1e-20j),
            atol=0,
            rtol=1e-15
        )

    def test_834(self):
        # Regression test for #834
        a = sc.exp1(-complex(19.9999990))
        b = sc.exp1(-complex(19.9999991))
        assert_allclose(a.imag, b.imag, atol=0, rtol=1e-15)

Since we are not interested in the complex part, I think we should leave it as it is.

Contributor: Interesting, okay we can leave it.

torch_result = torch_exp1(random_tensor)
assert np.allclose(scipy_result, torch_result.numpy(), atol=1e-15)
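A possible follow-up to the edge-case discussion above (a sketch, not part of this PR): pin a few fixed and boundary values, including x == 0 and the switch between the series and continued-fraction branches at x == 1, still using scipy as the reference.

def test_torch_exp1_fixed_and_edge_values():
    x = torch.tensor([1e-6, 0.5, 1.0, 1.0 + 1e-7, 10.0, 100.0], dtype=torch.float64)
    assert np.allclose(torch_exp1(x).numpy(), exp1(x.numpy()), rtol=1e-10, atol=0.0)
    # x == 0 maps to +inf, matching scipy.special.exp1(0.0)
    assert torch.isinf(torch_exp1(torch.tensor([0.0], dtype=torch.float64))).all()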


def test_torch_exp1_derivative():
Expand Down
Loading