Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add PyTorch implementation of the exponential integral function #145

Open
wants to merge 2 commits into
base: main
Choose a base branch
from

Conversation

E-Rum
Copy link
Contributor

@E-Rum E-Rum commented Jan 18, 2025

This PR replaces the exponential integral function, previously implemented using a wrapped SciPy realization with Torch tensors, with a manually implemented pure Torch version. The change was made because the previous implementation involved casting tensors to NumPy and back to Torch, which reduced Torch-PME-based MLIP training time by a factor of four.

Additionally, there is a small modification to a background_correction function. As PyTorch Lightning doesn’t correctly handle newly created tensors during training, it is better to zero out existing tensors rather than creating new ones from scratch.


📚 Documentation preview 📚: https://torch-pme--145.org.readthedocs.build/en/145/

assert np.allclose(torch_result.numpy(), scipy_result, atol=1e-6)
random_tensor = torch.FloatTensor(100000).uniform_(0, 1000)
random_array = random_tensor.numpy()
scipy_result = exp1(random_array)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we can use fixed values. We should check edge cases.

Maybe check the scipy's exp1 for the values they have tests for.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These are the only tests I was able to find in the SciPy repository

class TestExp1:

    def test_branch_cut(self):
        assert np.isnan(sc.exp1(-1))
        assert sc.exp1(complex(-1, 0)).imag == (
            -sc.exp1(complex(-1, -0.0)).imag
        )

        assert_allclose(
            sc.exp1(complex(-1, 0)),
            sc.exp1(-1 + 1e-20j),
            atol=0,
            rtol=1e-15
        )
        assert_allclose(
            sc.exp1(complex(-1, -0.0)),
            sc.exp1(-1 - 1e-20j),
            atol=0,
            rtol=1e-15
        )

    def test_834(self):
        # Regression test for #834
        a = sc.exp1(-complex(19.9999990))
        b = sc.exp1(-complex(19.9999991))
        assert_allclose(a.imag, b.imag, atol=0, rtol=1e-15)

Since we are not interested in the complex part, I think we should leave it as it is

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Interesting, okay we can leave it.

assert np.allclose(torch_result.numpy(), scipy_result, atol=1e-6)
random_tensor = torch.FloatTensor(100000).uniform_(0, 1000)
random_array = random_tensor.numpy()
scipy_result = exp1(random_array)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Interesting, okay we can leave it.

Comment on lines +42 to +46
for k in range(1, 26):
r = -r * k * x_small / (k + 1.0) ** 2
e1 += r
if torch.all(torch.abs(r) <= torch.abs(e1) * 1e-15):
break
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice! Did you check if this is faster.

I am just worried by the two loops here. But maybe it is fine.

@@ -41,13 +78,13 @@ def gammaincc_over_powerlaw(exponent: torch.Tensor, z: torch.Tensor) -> torch.Te
if exponent == 2:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you add the parameter docstring of gammaincc_over_powerlaw for me 😍


@staticmethod
def forward(ctx, input):
ctx.save_for_backward(input)
input_numpy = input.cpu().numpy() if not input.is_cpu else input.numpy()
return torch.tensor(exp1(input_numpy), device=input.device, dtype=input.dtype)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
# this implementation is inspired by the one in scipy: https://github.com/scipy/scipy/blob/55247d45469cf6203e8f562ab2b7e081ce41d0d1/scipy/special/xsf/cephes/exp10.h#L89

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants