Add PyTorch implementation of the exponential integral function #145
base: main
Changes from all commits
```diff
@@ -1,5 +1,4 @@
 import torch
-from scipy.special import exp1
 from torch.special import gammaln
```
```diff
@@ -15,39 +14,83 @@ def gamma(x: torch.Tensor) -> torch.Tensor:


 class CustomExp1(torch.autograd.Function):
-    """Custom exponential integral function Exp1(x) to have an autograd-compatible version."""
+    """
+    Compute the exponential integral E1(x) for x > 0.
+    :param input: Input tensor (x > 0)
+    :return: Exponential integral E1(x)
+    """

     @staticmethod
     def forward(ctx, input):
         ctx.save_for_backward(input)
-        input_numpy = input.cpu().numpy() if not input.is_cpu else input.numpy()
-        return torch.tensor(exp1(input_numpy), device=input.device, dtype=input.dtype)
+
+        # Constants
+        SCIPY_EULER = (
+            0.577215664901532860606512090082402431  # Euler-Mascheroni constant
+        )
+        inf = torch.inf
+
+        # Handle case when x == 0
+        result = torch.full_like(input, inf)
+        mask = input > 0
+
+        # Compute for x <= 1
+        x_small = input[mask & (input <= 1)]
+        if x_small.numel() > 0:
+            e1 = torch.ones_like(x_small)
+            r = torch.ones_like(x_small)
+            for k in range(1, 26):
+                r = -r * k * x_small / (k + 1.0) ** 2
+                e1 += r
+                if torch.all(torch.abs(r) <= torch.abs(e1) * 1e-15):
+                    break
```
Review comment on lines +42 to +46: Nice! Did you check if this is faster? I am just worried by the two loops here. But maybe it is fine.
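A hypothetical way to answer the speed question (the sizes and the `scipy_roundtrip` helper below are illustrative, not from the PR, and `CustomExp1` is assumed to be in scope) is to time the pure-torch forward against the old scipy round-trip:

```python
import timeit

import torch
from scipy.special import exp1 as scipy_exp1

x = torch.rand(100_000, dtype=torch.float64) * 10


def scipy_roundtrip(t):
    # The pre-PR approach: leave torch, evaluate in scipy, convert back.
    return torch.from_numpy(scipy_exp1(t.numpy()))


print("torch:", timeit.timeit(lambda: CustomExp1.apply(x), number=100))
print("scipy:", timeit.timeit(lambda: scipy_roundtrip(x), number=100))
```

Note that both Python loops have a size-independent iteration count (at most 25 series terms, and `m.max()` is at most 100 since here `x > 1`), so every iteration stays fully vectorized over the tensor.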
```diff
+            result[mask & (input <= 1)] = (
+                -SCIPY_EULER - torch.log(x_small) + x_small * e1
+            )
+
+        # Compute for x > 1
+        x_large = input[mask & (input > 1)]
+        if x_large.numel() > 0:
+            m = 20 + (80.0 / x_large).to(torch.int32)
+            t0 = torch.zeros_like(x_large)
+            for k in range(m.max(), 0, -1):
+                t0 = k / (1.0 + k / (x_large + t0))
+            t = 1.0 / (x_large + t0)
+            result[mask & (input > 1)] = torch.exp(-x_large) * t
+
+        return result
```
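For reference, the two branches implement the standard small- and large-argument representations of the exponential integral (this is my reading of the code, not text from the PR): the power series

$$E_1(x) = -\gamma - \ln x + \sum_{k=1}^{\infty} \frac{(-1)^{k+1} x^k}{k \cdot k!} \qquad (0 < x \le 1),$$

with $\gamma$ the Euler-Mascheroni constant, and the continued fraction

$$E_1(x) = e^{-x}\, \cfrac{1}{x + \cfrac{1}{1 + \cfrac{1}{x + \cfrac{2}{1 + \cfrac{2}{x + \cdots}}}}} \qquad (x > 1),$$

which the `for k in range(m.max(), 0, -1)` loop evaluates from the tail inward, the numerically stable direction.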
```diff

     @staticmethod
     def backward(ctx, grad_output):
         (input,) = ctx.saved_tensors
         return -grad_output * torch.exp(-input) / input


-def torch_exp1(input):
+def exp1(input):
     """Wrapper for the custom exponential integral function."""
     return CustomExp1.apply(input)
```
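Since `backward` returns the analytic derivative $dE_1(x)/dx = -e^{-x}/x$, a quick sanity check with `torch.autograd.gradcheck` could look like this (a hypothetical check, not part of the PR; `CustomExp1` is assumed to be in scope):

```python
import torch

# gradcheck compares the analytic backward (-exp(-x) / x) against
# finite differences of the forward; float64 is required for this.
x = torch.tensor([0.5, 1.0, 2.0, 10.0], dtype=torch.float64, requires_grad=True)
assert torch.autograd.gradcheck(CustomExp1.apply, (x,), atol=1e-8)
```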
```diff


 def gammaincc_over_powerlaw(exponent: torch.Tensor, z: torch.Tensor) -> torch.Tensor:
-    """Function to compute the regularized incomplete gamma function complement for integer exponents."""
+    """
+    Function to compute the regularized incomplete gamma function complement for integer
+    exponents.
+
+    :param exponent: Exponent of the power law
+    :param z: Value at which to evaluate the function
+    :return: Regularized incomplete gamma function complement
+    """
     if exponent == 1:
         return torch.exp(-z) / z
     if exponent == 2:
```
Review comment: Can you add the parameter docstring of this function?
```diff
         return torch.sqrt(torch.pi / z) * torch.erfc(torch.sqrt(z))
     if exponent == 3:
-        return torch_exp1(z)
+        return exp1(z)
     if exponent == 4:
         return 2 * (
             torch.exp(-z) - torch.sqrt(torch.pi * z) * torch.erfc(torch.sqrt(z))
         )
     if exponent == 5:
-        return torch.exp(-z) - z * torch_exp1(z)
+        return torch.exp(-z) - z * exp1(z)
     if exponent == 6:
         return (
             (2 - 4 * z) * torch.exp(-z)
```
The second file in the diff updates the corresponding tests:
```diff
@@ -2,18 +2,22 @@
 import torch
 from scipy.special import exp1

-from torchpme.lib import torch_exp1
+from torchpme.lib import exp1 as torch_exp1


 def finite_difference_derivative(func, x, h=1e-5):
     return (func(x + h) - func(x - h)) / (2 * h)
```
```diff
 def test_torch_exp1_consistency_with_scipy():
-    x = torch.rand(1000, dtype=torch.float64)
-    torch_result = torch_exp1(x)
-    scipy_result = exp1(x.numpy())
-    assert np.allclose(torch_result.numpy(), scipy_result, atol=1e-6)
+    seed = 42
+    torch.manual_seed(seed)
+    np.random.seed(seed)
+    random_tensor = torch.rand(100000) * 1000
+    random_array = random_tensor.numpy()
+    scipy_result = exp1(random_array)
```
Review comment: I think we can use fixed values. We should check edge cases. Maybe check scipy's exp1 for the values they have tests for.

Reply: These are the only tests I was able to find in the SciPy repository. Since we are not interested in the complex part, I think we should leave it as it is.

Reply: Interesting, okay, we can leave it.
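A hypothetical fixed-value test along the lines of that suggestion (the reference values are $E_1$ at a few points, quoted to about ten digits; they should be double-checked against `scipy.special.exp1` before use):

```python
def test_torch_exp1_fixed_values():
    x = torch.tensor([0.1, 1.0, 2.0], dtype=torch.float64)
    # Reference values for E1, e.g. from scipy.special.exp1,
    # quoted here to roughly ten digits.
    expected = torch.tensor(
        [1.8229239584, 0.2193839344, 0.0489005107], dtype=torch.float64
    )
    assert torch.allclose(torch_exp1(x), expected, rtol=0.0, atol=1e-9)
```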
```diff
+    torch_result = torch_exp1(random_tensor)
+    assert np.allclose(scipy_result, torch_result.numpy(), atol=1e-15)


 def test_torch_exp1_derivative():
```
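The body of `test_torch_exp1_derivative` is cut off in this view; a minimal sketch of such a test, assuming it compares the autograd gradient against the `finite_difference_derivative` helper defined above, might be:

```python
def test_torch_exp1_derivative():
    x = torch.rand(100, dtype=torch.float64) + 0.1  # keep x well above zero
    x.requires_grad_(True)
    torch_exp1(x).sum().backward()
    numeric = finite_difference_derivative(torch_exp1, x.detach())
    assert torch.allclose(x.grad, numeric, atol=1e-6)
```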