Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Change TORCH_LIBRARY to TORCH_LIBRARY_FRAGMENT #1645

Open
wants to merge 10 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 13 additions & 2 deletions .github/workflows/torchao_experimental_test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -34,10 +34,21 @@ jobs:
run: |
conda activate venv
pip install --extra-index-url "https://download.pytorch.org/whl/nightly/cpu" torch=="2.6.0.dev20250104"
pip install numpy
pip install pytest
pip install numpy pytest pyyaml parameterized
USE_CPP=1 pip install .
- name: Run tests
run: |
conda activate venv
pytest torchao/experimental/tests/test_int8_dynamic_activation_intx_weight.py
- name: MPS tests
run: |
conda activate venv
pushd torchao/experimental/ops/mps

echo "Building"
bash build.sh

echo "Running test"
python test/test_lowbit.py

popd
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@
"_linear_8bit_act_" #weight_nbit "bit_weight", \
&linear_meta<weight_nbit, true>);

TORCH_LIBRARY(torchao, m) {
TORCH_LIBRARY_FRAGMENT(torchao, m) {
DEFINE_OP(1);
DEFINE_OP(2);
DEFINE_OP(3);
Expand Down
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@manuelcandales how do I test the change? Is there a script I should run?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

cd ao/torchao/experimental/ops/mps
bash build.sh
python test/test_lowbit.py

Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,7 @@ Tensor pack_weights_cpu_kernel(const Tensor& W) {
return B;
}

TORCH_LIBRARY(torchao, m) {
TORCH_LIBRARY_FRAGMENT(torchao, m) {
m.def("_pack_weight_1bit(Tensor W) -> Tensor");
m.def("_pack_weight_2bit(Tensor W) -> Tensor");
m.def("_pack_weight_3bit(Tensor W) -> Tensor");
Expand Down
17 changes: 10 additions & 7 deletions torchao/experimental/ops/mps/test/test_lowbit.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,20 +10,23 @@
import torch
from parameterized import parameterized

libname = "libtorchao_ops_mps_aten.dylib"
libpath = os.path.abspath(
os.path.join(os.path.dirname(__file__), "../cmake-out/lib/", libname)
)

try:
print("TRYING")
for nbit in range(1, 8):
getattr(torch.ops.torchao, f"_linear_fp_act_{nbit}bit_weight")
getattr(torch.ops.torchao, f"_pack_weight_{nbit}bit")
except AttributeError:
try:
print("LOADING LIB")
libname = "libtorchao_ops_mps_aten.dylib"
libpath = os.path.abspath(os.path.join(os.path.dirname(__file__), "../cmake-out/lib/", libname))
print("AT ", libpath)
torch.ops.load_library(libpath)
except:
raise RuntimeError(f"Failed to load library {libpath}")
print("LOADED")
except Exception as e:
print("FAILED TO LOAD")
raise e
# raise RuntimeError(f"Failed to load library {libpath}")
else:
try:
for nbit in range(1, 8):
Expand Down
Loading