
Support positional encoding interpolation in Zoo models #601

Open
korotaS opened this issue Jun 17, 2024 · 3 comments

@korotaS
Contributor

korotaS commented Jun 17, 2024

Hi! I am using this example to train a ConcatSiamese model. With the extractor from the example (vits16_dino) it runs fine, but with a different model, for example vitb32_unicom, I get an error. Here is a minimal example to reproduce it:

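import torch

# import paths assume OML's public exports
from oml.models import ConcatSiamese, ViTUnicomExtractor
from oml.registry.transforms import get_transforms_for_pretrained

device = 'cpu'
extractor = ViTUnicomExtractor.from_pretrained("vitb32_unicom").to(device)
transforms, _ = get_transforms_for_pretrained("vitb32_unicom")
pairwise_model = ConcatSiamese(extractor=extractor, mlp_hidden_dims=[100], device=device)

out = pairwise_model(x1=torch.rand(2, 3, 224, 224), x2=torch.rand(2, 3, 224, 224))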

And here is the last traceback item:

File /home/korotas/projects/open-metric-learning/oml/models/vit_unicom/external/vision_transformer.py:181, in VisionTransformer.forward_features(self, x)
    179 B = x.shape[0]
    180 x = self.patch_embed(x)
--> 181 x = x + self.pos_embed
    182 for func in self.blocks:
    183     x = func(x)

RuntimeError: The size of tensor a (98) must match the size of tensor b (49) at non-singleton dimension 1

I think the problem is that vits16_dino interpolates its positional embedding to match the number of patches in the input before adding it:

def interpolate_pos_encoding(self, x, w: int, h: int):
    npatch = x.shape[1] - 1
    N = self.pos_embed.shape[1] - 1
    if npatch == N and w == h:
        return self.pos_embed
    class_pos_embed = self.pos_embed[:, 0]
    patch_pos_embed = self.pos_embed[:, 1:]
    dim = x.shape[-1]
    w0 = w // self.patch_embed.patch_size
    h0 = h // self.patch_embed.patch_size
    # we add a small number to avoid floating point error in the interpolation
    # see discussion at https://github.com/facebookresearch/dino/issues/8
    w0, h0 = w0 + 0.1, h0 + 0.1
    patch_pos_embed = nn.functional.interpolate(
        patch_pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2),
        scale_factor=(w0 / math.sqrt(N), h0 / math.sqrt(N)),
        mode="bicubic",
    )
    assert int(w0) == patch_pos_embed.shape[-2] and int(h0) == patch_pos_embed.shape[-1]
    patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
    return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)

But vitb32_unicom has no such interpolation, so the concatenated input yields twice as many patches as the positional encoding expects (98 vs 49). To make it work, we have to manually replace the layers that depend on the number of patches (model.pos_embed and model.feature[0]).
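To make the size mismatch concrete, here is the patch-count arithmetic as a small sketch. It assumes vitb32_unicom uses non-overlapping 32×32 patches on 224×224 inputs (consistent with the 49 in the traceback) and that ConcatSiamese joins the two images along one spatial axis; `num_patches` is a hypothetical helper, not part of OML.

```python
def num_patches(height: int, width: int, patch_size: int) -> int:
    """Patches produced by a non-overlapping ViT patch embedding (no class token)."""
    return (height // patch_size) * (width // patch_size)


# What pos_embed of vitb32_unicom was built for: one 224x224 image, patch size 32.
expected = num_patches(224, 224, 32)  # 7 * 7 = 49

# Assumption: ConcatSiamese stacks two 224x224 images along one spatial axis.
actual = num_patches(224, 448, 32)  # 7 * 14 = 98

print(expected, actual)  # 49 98
```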

I think this should be mentioned in the docs, or perhaps handled in ViTUnicomExtractor or ConcatSiamese with some kind of warning. The same error also occurs with all other ViTUnicomExtractor and ViTCLIPExtractor models.

@AlekseySh AlekseySh self-assigned this Jun 17, 2024
@AlekseySh
Contributor

AlekseySh commented Jun 17, 2024

@korotaS thank you for the report!

You are absolutely right. Note that we haven't changed the original implementations (everything under the external folder is copy-pasted as is). The problem is indeed the missing interpolation.

As for a fix on the library side, I suggest a compromise.
Let's add the interpolation, but also raise a warning whenever the image size requires it. In other words, we warn the user when the code behaves slightly differently from the original implementation.
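A rough sketch of what such an interpolation could look like for vitb32_unicom, whose pos_embed is added directly after patch_embed (per the traceback) and therefore, unlike DINO's, has no class-token slot. The function name `resize_patch_pos_embed` and its grid arguments are hypothetical; the sketch assumes patches are flattened row-major over (height, width), and it passes `size=` to `F.interpolate` instead of `scale_factor=`, which sidesteps the floating point workaround in DINO's version. It only adapts pos_embed; model.feature[0], which also depends on the number of patches, would still need separate handling.

```python
from typing import Tuple

import torch
import torch.nn.functional as F


def resize_patch_pos_embed(
    pos_embed: torch.Tensor, old_grid: Tuple[int, int], new_grid: Tuple[int, int]
) -> torch.Tensor:
    """Resize a (1, H*W, D) positional embedding (no class token) to a new patch grid."""
    old_h, old_w = old_grid
    new_h, new_w = new_grid
    _, n, dim = pos_embed.shape
    assert n == old_h * old_w, "pos_embed length must match the old grid"
    if (old_h, old_w) == (new_h, new_w):
        return pos_embed  # nothing to do; this is where a warning would NOT fire
    # (1, N, D) -> (1, D, H, W) so that interpolate treats D as channels
    grid = pos_embed.reshape(1, old_h, old_w, dim).permute(0, 3, 1, 2)
    # an explicit target size gives exact output dims, no +0.1 rounding trick needed
    grid = F.interpolate(grid, size=(new_h, new_w), mode="bicubic", align_corners=False)
    # (1, D, new_h, new_w) -> (1, new_h * new_w, D)
    return grid.permute(0, 2, 3, 1).reshape(1, new_h * new_w, dim)


# Hypothetical usage: a 7x7 grid stretched to 7x14 for a pair of concatenated images.
pos_embed = torch.randn(1, 49, 768)
resized = resize_patch_pos_embed(pos_embed, (7, 7), (7, 14))
print(resized.shape)  # torch.Size([1, 98, 768])
```

The warning proposed above would naturally go right before the interpolation branch, i.e. whenever the old and new grids differ.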

There is also a problem we need to solve: the check will live in forward(), but we want to avoid thousands of warnings, so we should warn the user only once. We can use a decorator for this. Here is a draft:

import warnings

def warn_if_even(func):
    def wrapper(n):
        if n % 2 == 0:
            warnings.warn(f"Input value {n} is an even number")
        return func(n)
    return wrapper

@warn_if_even
def my_function(n):
    # Your function logic here
    print(f"Function called with argument: {n}")

# Example of calling the function with different values
for i in range(1, 6):
    my_function(i)
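The draft above warns on every even input. One way to get the "warn only once" behaviour is to keep a flag in the decorator's closure. In this sketch `warn_once`, `condition` and the message are hypothetical names, not existing OML API; in a model, `condition` would receive the same arguments as forward (e.g. `self, x`) and could check whether the input size differs from the pretrained one.

```python
import functools
import warnings


def warn_once(condition, message):
    """Emit `message` the first time `condition(*args, **kwargs)` is true, then stay silent."""

    def decorator(func):
        warned = False  # shared by all calls of the decorated function

        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            nonlocal warned
            if not warned and condition(*args, **kwargs):
                warnings.warn(message, UserWarning, stacklevel=2)
                warned = True
            return func(*args, **kwargs)

        return wrapper

    return decorator


@warn_once(lambda n: n % 2 == 0, "Input is an even number (shown only once)")
def my_function(n):
    return n * 10


# Even with the "always" filter, the decorator itself limits output to one warning.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    results = [my_function(i) for i in range(1, 6)]

print(results, len(caught))  # [10, 20, 30, 40, 50] 1
```

Python's default warning filter already suppresses repeats of an identical message from the same line, but an explicit flag keeps the behaviour independent of the user's filter settings. For a method, the flag is shared across all instances of the class.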

Do you want to make this contribution to the library, @korotaS ? I would greatly appreciate it :)

@AlekseySh AlekseySh added experiments good first issue Good for newcomers and removed question labels Jun 17, 2024
@AlekseySh AlekseySh changed the title ConcatSiamese model error with non-DINO extractors Support positional encoding interpolation in Zoo models Jun 17, 2024
@korotaS
Contributor Author

korotaS commented Jun 17, 2024

Do you think interpolating is better than replacing some layers? I know it is generally not a good idea to replace trained layers with untrained ones, but they might learn quickly (especially the positional encoding).

As for this:

Do you want to make this contribution to the library, @korotaS?

I think I can, though maybe not today or tomorrow. I will let you know asap.

@AlekseySh
Contributor

@korotaS

In DINO they also trained on a fixed set of image sizes, but at inference time they allow interpolation.

We have a simple way to check that everything works: just run validation on any of our benchmarks with im_size=360 and compare the results with the ones in the zoo table.

Status: To do

No branches or pull requests

2 participants