Fix bug with some LoRA variants when applied to bitsandbytes NF4 quantized models #7577

Merged 2 commits on Jan 21, 2025
5 changes: 3 additions & 2 deletions invokeai/backend/patches/layers/lora_layer_base.py
@@ -4,6 +4,7 @@

 import invokeai.backend.util.logging as logger
 from invokeai.backend.patches.layers.base_layer_patch import BaseLayerPatch
+from invokeai.backend.patches.layers.param_shape_utils import get_param_shape
 from invokeai.backend.util.calc_tensor_size import calc_tensors_size


@@ -67,8 +68,8 @@ def get_parameters(self, orig_parameters: dict[str, torch.Tensor], weight: float
         # Reshape all params to match the original module's shape.
         for param_name, param_weight in params.items():
             orig_param = orig_parameters[param_name]
-            if param_weight.shape != orig_param.shape:
-                params[param_name] = param_weight.reshape(orig_param.shape)
+            if param_weight.shape != get_param_shape(orig_param):
+                params[param_name] = param_weight.reshape(get_param_shape(orig_param))

         return params

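For context, a minimal sketch (not part of the PR) of what this change guards against: on an NF4-quantized layer, `orig_param.shape` reports the packed storage shape rather than the logical weight shape, so using it as a reshape target goes wrong. The concrete shapes below and the reshape-error failure mode are illustrative assumptions, with plain tensors standing in for the real parameters.

```python
import torch

logical_shape = torch.Size([32, 64])   # logical shape of the base weight (what quant_state records)
packed_shape = torch.Size([1024, 1])   # shape reported by the packed NF4 storage (illustrative)

param_weight = torch.randn(32 * 64)    # a LoRA-produced delta arriving flattened

# Old code path: the reshape target came from orig_param.shape, i.e. the packed storage
# shape, so the element counts no longer match and the reshape raises.
try:
    param_weight.reshape(packed_shape)
except RuntimeError as e:
    print(f"reshape to packed shape fails: {e}")

# New code path: the target comes from get_param_shape(orig_param), i.e. the logical shape.
assert param_weight.reshape(logical_shape).shape == logical_shape
```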
19 changes: 19 additions & 0 deletions invokeai/backend/patches/layers/param_shape_utils.py
@@ -0,0 +1,19 @@
+import torch
+
+try:
+    from bitsandbytes.nn.modules import Params4bit
+
+    bnb_available: bool = True
+except ImportError:
+    bnb_available: bool = False
+
+
+def get_param_shape(param: torch.Tensor) -> torch.Size:
+    """A helper function to get the shape of a parameter that handles `bitsandbytes.nn.Params4Bit` correctly."""
+    # Accessing the `.shape` attribute of `bitsandbytes.nn.Params4Bit` will return an incorrect result. Instead, we must
+    # access the `.quant_state.shape` attribute.
+    if bnb_available and type(param) is Params4bit:  # type: ignore
+        quant_state = param.quant_state
+        if quant_state is not None:
+            return quant_state.shape
+    return param.shape
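For reference, a small usage sketch of the new helper (not included in the PR). It assumes bitsandbytes is installed and a CUDA device is available, since `Params4bit` quantizes its data when moved to the GPU; the packed shape mentioned in the comment is illustrative.

```python
import torch
from bitsandbytes.nn.modules import Params4bit

from invokeai.backend.patches.layers.param_shape_utils import get_param_shape

weight = torch.randn(32, 64)

# Plain tensors are unaffected: get_param_shape() simply returns .shape.
assert get_param_shape(weight) == torch.Size([32, 64])

# NF4-quantized parameter: quantization happens on transfer to the GPU, after which
# .shape reflects the packed storage buffer (e.g. torch.Size([1024, 1])) while the
# logical (32, 64) shape is recorded on quant_state.
quantized = Params4bit(weight, requires_grad=False, quant_type="nf4").cuda()
print(quantized.shape)             # packed storage shape
print(get_param_shape(quantized))  # torch.Size([32, 64]), read from quant_state.shape
```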
@@ -15,6 +15,7 @@
 from invokeai.backend.patches.layers.base_layer_patch import BaseLayerPatch
 from invokeai.backend.patches.layers.concatenated_lora_layer import ConcatenatedLoRALayer
 from invokeai.backend.patches.layers.flux_control_lora_layer import FluxControlLoRALayer
+from invokeai.backend.patches.layers.lokr_layer import LoKRLayer
 from invokeai.backend.patches.layers.lora_layer import LoRALayer
 from invokeai.backend.util.original_weights_storage import OriginalWeightsStorage
 from tests.backend.model_manager.load.model_cache.torch_module_autocast.custom_modules.test_custom_invoke_linear_8_bit_lt import (
@@ -282,6 +283,7 @@ def test_inference_autocast_from_cpu_to_device(device: str, layer_under_test: La
"multiple_loras",
"concatenated_lora",
"flux_control_lora",
"single_lokr",
]
)
def patch_under_test(request: pytest.FixtureRequest) -> PatchUnderTest:
@@ -350,6 +352,20 @@ def patch_under_test(request: pytest.FixtureRequest) -> PatchUnderTest:

         input = torch.randn(1, patched_in_features)
         return ([(lora_layer, 0.7)], input)
+    elif layer_type == "single_lokr":
+        lokr_layer = LoKRLayer(
+            w1=torch.randn(rank, rank),
+            w1_a=None,
+            w1_b=None,
+            w2=torch.randn(out_features // rank, in_features // rank),
+            w2_a=None,
+            w2_b=None,
+            t2=None,
+            alpha=1.0,
+            bias=torch.randn(out_features),
+        )
+        input = torch.randn(1, in_features)
+        return ([(lokr_layer, 0.7)], input)
     else:
         raise ValueError(f"Unsupported layer_type: {layer_type}")
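The shapes in the new `single_lokr` fixture follow from LoKR's Kronecker-product factorization: with full `w1` and `w2` matrices (no low-rank a/b split), the weight delta is `w1 ⊗ w2`, so `(rank, rank) ⊗ (out_features // rank, in_features // rank)` yields `(out_features, in_features)`. A quick standalone check (not part of the test suite), using illustrative dimensions:

```python
import torch

rank, in_features, out_features = 4, 32, 64

w1 = torch.randn(rank, rank)
w2 = torch.randn(out_features // rank, in_features // rank)

# The Kronecker product of an (a, b) and a (c, d) matrix has shape (a*c, b*d).
delta_w = torch.kron(w1, w2)
assert delta_w.shape == (out_features, in_features)
```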
