In the PopTorch documentation I read about the poptorch.identity_loss() function, which should be equivalent to PyTorch's backward() function. Is there a way to access the gradient on the input tensor after the identity_loss() call?
I made a small minimal example showing that I can either call backward() and read the input.grad attribute, or use autograd.grad to retrieve the derivative with respect to an input.
My question is: how can I retrieve the gradient on the input tensor after an identity_loss call and return it as an additional return value?
Which argument would I have to pass to the identity_loss method?
Can I access this gradient even with the inferenceModel wrapper?
Thank you for your help :)
Minimal example:
import torch
import poptorch
from torch import nn

input_dim = 2
hidden_dim = 4
output_dim = 1

# Small MLP shared by all three wrappers below.
model_0 = nn.Sequential(nn.Linear(input_dim, hidden_dim), nn.Tanh(), nn.Linear(hidden_dim, output_dim)).double()


class IPUModel(nn.Module):
    def __init__(self):
        super(IPUModel, self).__init__()
        self.nn = model_0

    def forward(self, in_):
        out = self.nn(in_)
        # Intended as the IPU counterpart of out.backward(...) in CPUModelBackward.
        poptorch.identity_loss(out, reduction="none")
        # Attempt to return the input gradient alongside the output.
        return out, in_.grad


class CPUModel(nn.Module):
    def __init__(self):
        super(CPUModel, self).__init__()
        self.nn = model_0

    def forward(self, in_):
        out = self.nn(in_)
        return out


class CPUModelBackward(nn.Module):
    def __init__(self):
        super(CPUModelBackward, self).__init__()
        self.nn = model_0

    def forward(self, in_):
        out = self.nn(in_)
        # Seed the backward pass with ones, i.e. the gradient of out.sum().
        out.backward(torch.ones_like(out))
        # Populated by backward(); in_ is the leaf tensor x passed in below.
        return out, in_.grad


model_IPU = poptorch.inferenceModel(IPUModel())
model_CPU = CPUModel()
model_CPU_back = CPUModelBackward()

x = torch.tensor(
    [[1, 2], [3, 42]], dtype=torch.float64, requires_grad=True
)

y_ipu, grad_ipu = model_IPU(x)

# Reference 1: input gradient via torch.autograd.grad (does not accumulate into x.grad).
y_cpu = model_CPU(x)
autograd_grad = torch.autograd.grad(y_cpu, x, retain_graph=True, grad_outputs=torch.ones_like(y_cpu))[0]

# Reference 2: input gradient via backward() and x.grad.
y, grad_input = model_CPU_back(x)

print(f"The input.grad from the IPU is unfortunately None: {grad_ipu}")
if torch.all(autograd_grad.eq(grad_input)):
    print(f"Success. This was expected. The gradient calculated by autograd.grad and by backward is {autograd_grad}")
else:
    print("Error!")
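One possible workaround, sketched under the assumption that the network stays a fixed Linear-Tanh-Linear stack like model_0: write the chain rule out by hand, so the input gradient is produced by ordinary forward ops (matmul, tanh, multiply) and can simply be returned from forward(), with no backward pass or .grad attribute involved. The helper name forward_with_input_grad is hypothetical, and whether PopTorch compiles it unchanged inside inferenceModel is untested here; the sketch only checks it against torch.autograd.grad on the CPU.

```python
import torch
from torch import nn


def forward_with_input_grad(mlp: nn.Sequential, x: torch.Tensor):
    """Hypothetical helper for nn.Sequential(Linear, Tanh, Linear).

    Returns (out, grad_x), where grad_x is what out.backward(torch.ones_like(out))
    would accumulate into x.grad, computed explicitly without autograd.
    """
    lin1, act, lin2 = mlp
    assert isinstance(act, nn.Tanh), "sketch assumes a Tanh activation"
    h = x @ lin1.weight.T + lin1.bias
    a = torch.tanh(h)
    out = a @ lin2.weight.T + lin2.bias
    # Chain rule, seeded with ones exactly like backward(torch.ones_like(out)).
    grad_a = torch.ones_like(out) @ lin2.weight  # (batch, hidden)
    grad_h = grad_a * (1 - a * a)  # tanh'(h) = 1 - tanh(h)^2
    grad_x = grad_h @ lin1.weight  # (batch, input)
    return out, grad_x


# Usage: compare against torch.autograd.grad on the CPU.
torch.manual_seed(0)
model = nn.Sequential(nn.Linear(2, 4), nn.Tanh(), nn.Linear(4, 1)).double()
x = torch.tensor([[1.0, 2.0], [3.0, 42.0]], dtype=torch.float64, requires_grad=True)
out, grad_x = forward_with_input_grad(model, x)
ref = torch.autograd.grad(model(x), x, grad_outputs=torch.ones_like(out))[0]
print(torch.allclose(grad_x, ref))
```

This only stays manageable for small, fixed architectures; a SchNet-style model would need every gather in its forward pass mirrored by a hand-written scatter.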
Background to this question:
I am currently trying to run molecular dynamics simulations with a SchNet NN on the IPU. I use the torchmd-net implementation (https://github.com/torchmd/torchmd-net/blob/main/torchmdnet/models/torchmd_gn.py).
At the end of the model (https://github.com/torchmd/torchmd-net/blob/main/torchmdnet/models/model.py, line 289), the derivatives of the outputs with respect to the inputs are calculated with the autograd.grad function. On the IPU, this grad() call leads to "Unsupported ops found in compiled model: [aten::_index_put_impl, aten::index_add]" errors.
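A hedged illustration of where those two ops probably come from: message-passing models gather per-edge values from per-atom tensors by index, and in PyTorch the backward of such a gather is a scatter-add (index_add for index_select, an accumulating index_put for advanced indexing). The CPU sketch below uses a hypothetical harmonic pair energy; pair_energy, energy_and_forces and manual_forces are illustrative names, not torchmd-net code. It computes forces as -dE/dpos with torch.autograd.grad and checks them against a hand-written index_add scatter.

```python
import torch


def pair_energy(pos, edge_index, k=1.0, r0=1.0):
    """Hypothetical toy potential: harmonic springs on the listed atom pairs."""
    src, dst = edge_index
    # Per-edge gather from per-atom positions, as message passing does.
    r = pos.index_select(0, src) - pos.index_select(0, dst)
    d = r.norm(dim=-1)
    return 0.5 * k * ((d - r0) ** 2).sum()


def energy_and_forces(pos, edge_index):
    """Forces = -dE/dpos, computed with torch.autograd.grad."""
    pos = pos.detach().requires_grad_(True)
    energy = pair_energy(pos, edge_index)
    (dE_dpos,) = torch.autograd.grad(energy, pos)
    return energy.detach(), -dE_dpos


def manual_forces(pos, edge_index, k=1.0, r0=1.0):
    """Same forces by hand: each per-edge gradient is scattered back
    onto its two atoms with index_add, mirroring the gather's backward."""
    src, dst = edge_index
    r = pos[src] - pos[dst]
    d = r.norm(dim=-1, keepdim=True)
    dE_dr = k * (d - r0) * r / d
    grad = torch.zeros_like(pos).index_add(0, src, dE_dr).index_add(0, dst, -dE_dr)
    return -grad


pos = torch.tensor([[0.0, 0.0, 0.0], [1.5, 0.0, 0.0], [0.0, 2.0, 0.0]], dtype=torch.float64)
edge_index = torch.tensor([[0, 0, 1], [1, 2, 2]])
energy, forces = energy_and_forces(pos, edge_index)
print(energy.item())
print(torch.allclose(forces, manual_forces(pos, edge_index)))
```

Writing the forces out by hand makes the scatter explicit, but it does not by itself avoid index_add, so this only shows where the unsupported ops likely originate.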