Skip to content

Commit

Permalink
chore: improve link-checker, fix Tracer issue and remove solved issues
Browse files Browse the repository at this point in the history
  • Loading branch information
kcelia authored Nov 7, 2024
1 parent 4829444 commit e396438
Show file tree
Hide file tree
Showing 7 changed files with 156 additions and 81 deletions.
10 changes: 6 additions & 4 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -704,14 +704,16 @@ check_links:
@# To avoid some issues with priviledges and linkcheckmd
find docs/ -name "*.md" -type f | xargs chmod +r

@# Run linkcheck on mardown files. It is mainly used for web links
@# Run linkcheck on markdown files to check only local files
poetry run python -m linkcheckmd docs -local
poetry run python -m linkcheckmd README.md

@# Check that relative links in mardown files are targeting existing files
@# Check that web links are functional or not broken
poetry run python ./script/make_utils/check_links_with_agent.py README.md --verbose

@# Check that relative links in markdown files are targeting existing files
poetry run python ./script/make_utils/local_link_check.py

@# Check that links to mardown headers in mardown files are targeting existing headers
@# Check that links to markdown headers in markdown files are targeting existing headers
poetry run python ./script/make_utils/check_headers.py

@# For weblinks and internal references
Expand Down
4 changes: 0 additions & 4 deletions docs/guides/hybrid-models.md
Original file line number Diff line number Diff line change
Expand Up @@ -144,10 +144,6 @@ For inference with the `HybridFHEModel` instance, `hybrid_model`, call the regul
hybrid_model(torch.randn((dim, )))
```

<!-- Add a forward method to hybridfhemodel?-->

<!-- FIXME: https://github.com/zama-ai/concrete-ml-internal/issues/4579-->

When calling `HybridFHEModel`, it handles all the necessary intermediate steps for each model part deployed remotely, including:

- Quantizing the data.
Expand Down
94 changes: 94 additions & 0 deletions script/make_utils/check_links_with_agent.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
"""
Check external web links.
Note: We avoid using `poetry run python -m linkcheckmd README.md` because
some servers restrict access when they detect requests from chatbots.
"""

import argparse
import re
import sys
from pathlib import Path
from typing import List

import requests


def check_links(file_path: Path, verbose: bool) -> List[str]:
"""Check the content of a markdown file for dead links.
Args:
file_path (Path): The path to the file.
verbose (bool): Enable verbose output.
Returns:
List[str]: a list of errors (dead-links) found.
"""

headers = {
"User-Agent": (
"Mozilla/5.0 (Windows NT 10.0; Win64; x64) "
"AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36"
)
}

broken_links = []

# Read the file
content = file_path.read_text(encoding="utf-8")

# Use a regular expression to extract all links starting with https up to the next whitespace.
raw_links = re.findall(r"https://\S+", content)

# Clean the links by removing everything after any of these characters: '", ], }, ), >'
links_to_check = [re.split(r'["\]\}\)>]', link)[0] for link in raw_links]

# Check each link
for link in links_to_check:
try:
response = requests.get(link, headers=headers, timeout=10)
if response.status_code == 200:
status_message = f"OK: {link}"
else:
status_message = f"Failed: {link} (Status Code: {response.status_code})"
broken_links.append(status_message)

except requests.exceptions.RequestException as e:
# Extract only the relevant part of the error message
status_message = (
f"Failed: {link} ({e.__class__.__name__}: {str(e).rsplit(':', maxsplit=1)[-1]})"
)
broken_links.append(status_message)

if verbose:
print(status_message)

return broken_links


def main():
"""Main function"""

# Set up argument parsing
parser = argparse.ArgumentParser(description="Check web links in a file.")
parser.add_argument("filename", help="The path to the file to check")
parser.add_argument("--verbose", action="store_true", help="Enable verbose output")
args = parser.parse_args()

if args.verbose:
print(f"checking external links {args.filename}")

# Create a Path object for the file
root = Path(".")
file_path = root / args.filename

broken_links = check_links(file_path, args.verbose)

# Exit with status code 1 if there are broken links
if broken_links:
print("\nBroken links:")
sys.exit("\n".join(broken_links))


if __name__ == "__main__":
main()
10 changes: 7 additions & 3 deletions src/concrete/ml/onnx/onnx_impl_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,10 +51,14 @@ def numpy_onnx_pad(
# the values on the edges to the input zero_point, which corresponds
# to the real-axis 0
if int_only:
# Work in integer Concrete mode
x_pad = fhe_ones(tuple(padded_shape)) * numpy.int64(pad_value)
if isinstance(x_pad, Tracer):
# Quantized execution: integer mode with tracing
x_pad = fhe_ones(tuple(padded_shape)) * numpy.int64(pad_value)
else:
# Quantized execution: integer mode without tracing
x_pad = numpy.ones(padded_shape, dtype=numpy.int64) * pad_value
else:
# Floating point mode
# Calibration mode: floating-point padding for non-quantized execution
x_pad = numpy.ones(padded_shape, dtype=numpy.float32) * pad_value
assert isinstance(x_pad, (numpy.ndarray, Tracer))

Expand Down
16 changes: 16 additions & 0 deletions src/concrete/ml/pytest/torch_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -1688,3 +1688,19 @@ def forward(self, x):
x = self.relu(x)
x = self.linear(x)
return x


class AllZeroCNN(CNNOther):
"""A CNN class that has all zero weights and biases."""

def __init__(self, input_output, activation_function):
super().__init__(input_output, activation_function)

for module in self.modules():
# assert m.bias is not None
# Disable mypy as it properly detects that module's bias term is None end therefore
# does not have a `data` attribute but fails to take into consideration the fact
# that `torch.nn.init.constant_` actually handles such a case
if isinstance(module, (nn.Conv2d, nn.Linear)):
torch.nn.init.constant_(module.weight.data, 0)
torch.nn.init.constant_(module.bias.data, 0) # type: ignore[union-attr]
4 changes: 1 addition & 3 deletions tests/quantization/test_quantized_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -301,9 +301,7 @@ def test_clip_op(
ARITH_N_BITS_LIST = [20, 16, 8]


# This test is a known flaky (in particular, with the QuantizedDiv operator)
# FIXME: https://github.com/zama-ai/concrete-ml-internal/issues/4563
@pytest.mark.flaky
# This test was known to be flaky, in particular, with the QuantizedDiv operator (see issue 4563).
@pytest.mark.parametrize(
"operator, supports_enc_with_enc, r2_threshold_bits",
[
Expand Down
99 changes: 32 additions & 67 deletions tests/torch/test_compile_torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from concrete.ml.pytest.torch_models import (
FC,
AddNet,
AllZeroCNN,
BranchingGemmModule,
BranchingModule,
CNNGrouped,
Expand Down Expand Up @@ -983,76 +984,40 @@ def test_qat_import_bits_check(default_configuration):
)


def test_qat_import_check(default_configuration, check_is_good_execution_for_cml_vs_circuit):
@pytest.mark.parametrize(
"model, input_shape, input_output",
[
# This model is trying to import a network that is QAT (has a quantizer in the graph)
# but the import bit-width is wrong (mismatch between bit-width specified in training
# and the bit-width specified during import). For NNs that are not built with Brevitas
# the bit-width must be manually specified and is used to infer quantization parameters.
(partial(StepFunctionPTQ, n_bits=6, disable_bit_check=True), None, 10),
# This network may look like QAT but it just zeros all inputs
(AllZeroCNN, (1, 7, 7), 1),
# This second case is a network that is not QAT but is being imported as a QAT network
(CNNOther, (1, 7, 7), 1),
],
)
def test_qat_import_check(
model,
input_shape,
input_output,
default_configuration,
check_is_good_execution_for_cml_vs_circuit,
):
"""Test two cases of custom (non brevitas) NNs where importing as QAT networks should fail."""
qat_bits = 4

simulate = True

error_message_pattern = "Error occurred during quantization aware training.*"

# This first test is trying to import a network that is QAT (has a quantizer in the graph)
# but the import bit-width is wrong (mismatch between bit-width specified in training
# and the bit-width specified during import). For NNs that are not built with Brevitas
# the bit-width must be manually specified and is used to infer quantization parameters.
with pytest.raises(ValueError, match=error_message_pattern):
with pytest.raises(ValueError, match="Error occurred during quantization aware training.*"):
compile_and_test_torch_or_onnx(
10,
partial(StepFunctionPTQ, n_bits=6, disable_bit_check=True),
nn.ReLU,
qat_bits,
default_configuration,
simulate,
False,
check_is_good_execution_for_cml_vs_circuit,
)

input_shape = (1, 7, 7)
input_output = input_shape[0]

# The second case is a network that is not QAT but is being imported as a QAT network
with pytest.raises(ValueError, match=error_message_pattern):
compile_and_test_torch_or_onnx(
input_output,
CNNOther,
nn.ReLU,
qat_bits,
default_configuration,
simulate,
False,
check_is_good_execution_for_cml_vs_circuit,
input_shape=input_shape,
)

class AllZeroCNN(CNNOther):
"""A CNN class that has all zero weights and biases."""

def __init__(self, input_output, activation_function):
super().__init__(input_output, activation_function)

for module in self.modules():
# assert m.bias is not None
# Disable mypy as it properly detects that module's bias term is None end therefore
# does not have a `data` attribute but fails to take into consideration the fact
# that `torch.nn.init.constant_` actually handles such a case
if isinstance(module, (nn.Conv2d, nn.Linear)):
torch.nn.init.constant_(module.weight.data, 0)
torch.nn.init.constant_(module.bias.data, 0) # type: ignore[union-attr]

input_shape = (1, 7, 7)
input_output = input_shape[0]

# A network that may look like QAT but it just zeros all inputs
with pytest.raises(ValueError, match=error_message_pattern):
compile_and_test_torch_or_onnx(
input_output,
AllZeroCNN,
nn.ReLU,
qat_bits,
default_configuration,
simulate,
False,
check_is_good_execution_for_cml_vs_circuit,
input_output_feature=input_output,
model_class=model,
activation_function=nn.ReLU,
qat_bits=4,
default_configuration=default_configuration,
simulate=True,
is_onnx=False,
check_is_good_execution_for_cml_vs_circuit=check_is_good_execution_for_cml_vs_circuit,
# For non-null input_shape values, input_output is input_shape[0]
input_shape=input_shape,
)

Expand Down

0 comments on commit e396438

Please sign in to comment.