-
-
Notifications
You must be signed in to change notification settings - Fork 104
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Refactors TryMatmuls for better error messages when dimensions don't match #681
base: main
Are you sure you want to change the base?
Conversation
Okay have you tried having a trait at the Dim level that does the dyn vs const? trait DimEq<Rhs> {
fn assert_dim_eq(...) { ... }
}
impl DimEq<usize> for usize { ... }
impl DimEq<Const<N>> for Const<M> { ... }
impl DimEq<usize> for Const<M> { ... }
impl DimEq<Const<M>> for usize { ... } That might let us combine the static vs dynamic traits you have now? impl<M: Dim, K1: Dim, K2: Dim, N: Dim> MulDimCheck<(M, K1)> for (K2, N)
where K1: DimEq<K2>
{ ... } |
By having only one I tried only having the static dim checks (which is what you are elluding to?), but attention complains even if we omit the dynamic dim checks, because we still need variations of TryMatMuls now that K is split into K1 and K2 (when K1, K2 are not Having the static and dynamic variants isn't so much of a problem imo, because they can really serve a purpose: either make a compile or a runtime check. The main problem is that, these traits need to be implemented for dynamic shapes multiple times because of what happens inside attention. In particular, To make the above clearer on your example, one would then need to implement both Even if I omit the dynamic implementations and checks, I still need to implement the dynamic matmuls. The problem remains the same: rust complains about conflicts in the implementations, which is mostly relevant for the rank3 and rank4 matmuls inside attention, so this is where one needs to split the traits to allow for implementations of all arising combinations of If we could force somehow, eg the output of |
Can you expand? |
This is what I also thought/hoped, but it seems like if rust thinks an object has usize as a dimension, then it is a different thing, it isn't a Eg in the code I have pushed in the PR, I have already implemented: impl<B: Dim, M: Dim, LeftK: Dim, RightK: Dim, N: Dim, E: Dtype, D, T, R>
TryStaticMatMul<Tensor<(B, RightK, N), E, D, R>> for Tensor<(B, M, LeftK), E, D, T>
where ... and impl<B: Dim, S1: Dim, S2: Dim, E: Dtype, D, T, R> TryDynamicMatMul<Tensor<(B, usize, S2), E, D, R>>
for Tensor<(B, S1, usize), E, D, T>
where ... and
So, combinations of let tokens = weights.try_dynamic_matmul(v)?; The error message for it:
Another dimension combination I haven't implemented ( I wish I am doing something dumb to be honest or misunderstanding some rust details. But not sure what that might be. My current insights are summarized in the last paragraph of my previous comment, if that helps. |
Here's the diff of the fixes! |
I haven't synced the branch yet with all the recent commits, but will of course do if this works out in the end.
So, I tried to move more of the dynamic dim checks into the respective traits before pushing, and now I need to split the traits and ops even more for stuff to work inside
mha.rs
. So, unless I "turn off" the dynamic checks as in e.g.where (B, S1, usize): MulDynamicDimCheck<(B, usize, S2)>;
(and apart from tests inside mod.rs needing to be fixed in any case), multiheaded attentions still spits out errors.Just so that you don't get completely lost in the mess that I have pushed, a quick summary of what I am doing:
TryMatMul<M,K,N>
from assuming that the crucial dimK
is always common between the two tensors, intoTryMatMul<M, LeftK, RightK, N>
, so that we can make the dim checks with the respective trait. This would already work quite good, but with dynamic dims we need to split into static or dynamic traits, depending on whether K is static or dynamic (the dynamic part mostly because of transformers).we get conflict errors:
I have also probably forgotten to implement some stuff while lost in the splitting of traits. All the implementations that I have commented out more or less need new variations of the traits for them to work.
It would be great if we could find a nicer (and more maintainable way) to get over this, because the errors now look quite good and informative (already tested for the static cases and some dynamic cases).