From 7b0e13ccf49196e41625df61ed2a3ae7fa82a6d1 Mon Sep 17 00:00:00 2001 From: Simon Adorf Date: Wed, 15 Jan 2025 22:42:46 +0000 Subject: [PATCH 01/22] Implement prototype for revised type reflection system. --- revised_type_reflection_system_prototype.py | 190 ++++++++++++++++++++ 1 file changed, 190 insertions(+) create mode 100644 revised_type_reflection_system_prototype.py diff --git a/revised_type_reflection_system_prototype.py b/revised_type_reflection_system_prototype.py new file mode 100644 index 0000000000..3328496578 --- /dev/null +++ b/revised_type_reflection_system_prototype.py @@ -0,0 +1,190 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. + +import inspect +import numpy as np +import cupy as cp +from dataclasses import dataclass +from functools import wraps +from collections import OrderedDict +from typing import Any + +## Everything related to API boundary determination + +global_api_counter : int = 0 + +def is_api_internal(): + return global_api_counter > 1 + + +def api_boundary(func): + + @wraps(func) + def inner(*args, **kwargs): + global global_api_counter + global_api_counter += 1 + try: + return func(*args, **kwargs) + finally: + global_api_counter -= 1 + + return inner + + +## CumlArray + +class CumlArray: + + def __init__(self, data): + self.data = data + + def to_output(self, output_type: str): + match output_type: + case "numpy": + return np.asarray(self.data) + case "cupy": + return cp.asarray(self.data) + case _: + raise TypeError(f"Unknown output_type '{output_type}'.") + + +## CumlArrayDescriptor + +class CumlArrayDescriptor: + + def __init__(self, order="K"): + self.order = order + + def __set_name__(self, owner, name): + self.name = name + + def __set__(self, obj, value): + # Just save the provided value as CumlArray + setattr(obj, f"_{self.name}_value", CumlArray(value)) + + def __get__(self, obj, objtype=None): + value = getattr(obj, f"_{self.name}_value") + if global_api_counter > 0: + return value + else: + output_type = _get_output_type(obj) + return value.to_output(output_type) + + +## Type reflection + +global_output_type = None + +def determine_array_type(value) -> str: + if isinstance(value, CumlArray): + return "cuml" + elif isinstance(value, np.ndarray): + return "numpy" + elif isinstance(value, cp.ndarray): + return "cupy" + else: + return ValueError(f"Unknown array type: {type(value)}") + +def _set_output_type(obj: Any, output_type: str): + setattr(obj, "_output_type", output_type) + +def _get_output_type(obj: Any): + if global_output_type is None: + return getattr(obj, "_output_type", None) + else: + return global_output_type + + +class set_output_type: # decorator + + def __init__(self, arg_name: str): + self.arg_name = arg_name + + def __call__(self, func): + sig = inspect.signature(func) + + @api_boundary + def inner(obj, *args, **kwargs): + if not is_api_internal(): + bound_args = sig.bind(obj, *args, **kwargs) + bound_args.apply_defaults() + + arg_value = bound_args.arguments.get(self.arg_name) + arg_type = determine_array_type(arg_value) + _set_output_type(obj, arg_type) + + return func(obj, *args, **kwargs) + + return inner + + +def to_output_type(return_value, output_type: str): + """Convert CumlArray and containers of CumlArray.""" + if type(return_value) is CumlArray: + return return_value.to_output(output_type) + elif type(return_value) is tuple: + return tuple(to_output_type(item) for item in return_value) + else: + return return_value + + +def convert_cuml_arrays(func): # decorator + + @wraps(func) + @api_boundary + def inner(obj, *args, **kwargs): + ret = func(obj, *args, **kwargs) + if is_api_internal(): + return ret + else: + output_type = _get_output_type(obj) + return to_output_type(ret, output_type) + + return inner + +## Example estimator implementation + +class MinimalLinearRegression: + + coef_ = CumlArrayDescriptor() + intercept_ = CumlArrayDescriptor() + + @set_output_type("X") + def fit(self, X, y): + X = CumlArray(X).to_output("numpy") + X_design = np.hstack([np.ones((X.shape[0], 1)), X]) + + # Compute coefficients using normal equation + weights = np.linalg.pinv(X_design.T @ X_design) @ X_design.T @ y + + # Separate intercept and coefficients + self.intercept_ = weights[0] + self.coef_ = weights[1:] + + return self + + @convert_cuml_arrays + def predict(self, X): + X = CumlArray(X).to_output("numpy") + y = X @ self.coef_.to_output("numpy") + self.intercept_.to_output("numpy") + return CumlArray(y) + + +def test(): + # Example usage + from sklearn.datasets import make_regression + + # Create synthetic data + X, y = make_regression(n_samples=20, n_features=2, noise=0.1, random_state=42) + + # Instantiate and train the estimator + model = MinimalLinearRegression() + model.fit(X, y) + + # Make predictions + predictions = model.predict(X) + print("Predictions:", predictions, type(predictions)) + print("coef", model.coef_, type(model.coef_)) + + +if __name__ == "__main__": + test() From ac3e5f3fc137bf08758feb4d7e00587e5541e788 Mon Sep 17 00:00:00 2001 From: Simon Adorf Date: Thu, 16 Jan 2025 17:16:54 +0000 Subject: [PATCH 02/22] fixup! Implement prototype for revised type reflection system. --- revised_type_reflection_system_prototype.py | 79 ++++++++++++++++----- 1 file changed, 61 insertions(+), 18 deletions(-) diff --git a/revised_type_reflection_system_prototype.py b/revised_type_reflection_system_prototype.py index 3328496578..ce10be72ed 100644 --- a/revised_type_reflection_system_prototype.py +++ b/revised_type_reflection_system_prototype.py @@ -7,6 +7,7 @@ from functools import wraps from collections import OrderedDict from typing import Any +from collections.abc import Sequence ## Everything related to API boundary determination @@ -40,12 +41,23 @@ def __init__(self, data): def to_output(self, output_type: str): match output_type: case "numpy": - return np.asarray(self.data) + if isinstance(self.data, cp.ndarray): + return self.data.get() + else: + return np.asarray(self.data) case "cupy": return cp.asarray(self.data) case _: raise TypeError(f"Unknown output_type '{output_type}'.") + def to_device_array(self) -> cp.ndarray: + return self.to_output("cupy") + + +def as_cuml_array(X) -> CumlArray: + """Wraps array X in CumlArray container.""" + return CumlArray(X) + ## CumlArrayDescriptor @@ -59,9 +71,11 @@ def __set_name__(self, owner, name): def __set__(self, obj, value): # Just save the provided value as CumlArray - setattr(obj, f"_{self.name}_value", CumlArray(value)) + setattr(obj, f"_{self.name}_value", as_cuml_array(value)) def __get__(self, obj, objtype=None): + # Return either the original value for internal access or convert to the + # desired output type. value = getattr(obj, f"_{self.name}_value") if global_api_counter > 0: return value @@ -75,6 +89,7 @@ def __get__(self, obj, objtype=None): global_output_type = None def determine_array_type(value) -> str: + """Utility function to identify the array type.""" if isinstance(value, CumlArray): return "cuml" elif isinstance(value, np.ndarray): @@ -87,14 +102,24 @@ def determine_array_type(value) -> str: def _set_output_type(obj: Any, output_type: str): setattr(obj, "_output_type", output_type) -def _get_output_type(obj: Any): +def _get_output_type(obj: Any) -> str: if global_output_type is None: return getattr(obj, "_output_type", None) else: return global_output_type -class set_output_type: # decorator +class set_output_type: + """Set a object's output_type based on a function argument type. + + Example: + + @set_output_type("X") + def fit(self, X, y): + ... + + Sets the output_type of self to the type of the X argument. + """ def __init__(self, arg_name: str): self.arg_name = arg_name @@ -117,17 +142,18 @@ def inner(obj, *args, **kwargs): return inner -def to_output_type(return_value, output_type: str): +def _to_output_type(obj, output_type: str): """Convert CumlArray and containers of CumlArray.""" - if type(return_value) is CumlArray: - return return_value.to_output(output_type) - elif type(return_value) is tuple: - return tuple(to_output_type(item) for item in return_value) + if isinstance(obj, CumlArray): + return obj.to_output(output_type) + elif isinstance(obj, Sequence) and not isinstance(obj, str): + return type(obj)(_to_output_type(item) for item in obj) else: - return return_value + return obj def convert_cuml_arrays(func): # decorator + """Cuml arrays in method return value are converted.""" @wraps(func) @api_boundary @@ -137,7 +163,7 @@ def inner(obj, *args, **kwargs): return ret else: output_type = _get_output_type(obj) - return to_output_type(ret, output_type) + return _to_output_type(ret, output_type) return inner @@ -148,24 +174,41 @@ class MinimalLinearRegression: coef_ = CumlArrayDescriptor() intercept_ = CumlArrayDescriptor() - @set_output_type("X") - def fit(self, X, y): - X = CumlArray(X).to_output("numpy") - X_design = np.hstack([np.ones((X.shape[0], 1)), X]) + # Private methods should not be at the API boundary and should + # not use the @set_output_type decorator. + + def _fit_on_device(self, X: cp.ndarray, y: cp.ndarray): + X_design = cp.hstack([cp.ones((X.shape[0], 1)), X]) # Compute coefficients using normal equation - weights = np.linalg.pinv(X_design.T @ X_design) @ X_design.T @ y + weights = cp.linalg.pinv(X_design.T @ X_design) @ X_design.T @ y # Separate intercept and coefficients self.intercept_ = weights[0] self.coef_ = weights[1:] + @set_output_type("X") + def fit(self, X, y): + # The implementation here is device specific. We delay the conversion to + # CumlArray and then device array to the latest possible moment. + X, y = as_cuml_array(X), as_cuml_array(y) + self._fit_on_device(X.to_device_array(), y.to_device_array()) + return self + def _predict_on_device(self, X: cp.ndarray) -> cp.ndarray: + # This is an API internal method, the array descriptor will not(!) + # perform an automatic conversion. + return X @ self.coef_.to_device_array() + self.intercept_.to_device_array() + @convert_cuml_arrays def predict(self, X): - X = CumlArray(X).to_output("numpy") - y = X @ self.coef_.to_output("numpy") + self.intercept_.to_output("numpy") + y = self._predict_on_device(as_cuml_array(X).to_device_array()) + + # By returning the result within the CumlArray container in a function + # at the API boundary decorated with @convert_cuml_arrays, we ensure + # that the return value is automatically converted to reflect the desired + # type. return CumlArray(y) From 5f8d9f694f46350de9c55d0dd2f56f444e928e1d Mon Sep 17 00:00:00 2001 From: Simon Adorf Date: Thu, 16 Jan 2025 18:06:53 +0000 Subject: [PATCH 03/22] fixup! Implement prototype for revised type reflection system. --- revised_type_reflection_system_prototype.py | 23 +++++++++++++++++---- 1 file changed, 19 insertions(+), 4 deletions(-) diff --git a/revised_type_reflection_system_prototype.py b/revised_type_reflection_system_prototype.py index ce10be72ed..ad7c2f39b6 100644 --- a/revised_type_reflection_system_prototype.py +++ b/revised_type_reflection_system_prototype.py @@ -3,11 +3,10 @@ import inspect import numpy as np import cupy as cp -from dataclasses import dataclass from functools import wraps -from collections import OrderedDict from typing import Any from collections.abc import Sequence +from contextlib import contextmanager ## Everything related to API boundary determination @@ -88,6 +87,17 @@ def __get__(self, obj, objtype=None): global_output_type = None +@contextmanager +def override_output_type(output_type: str): + global global_output_type + try: + previous_output_type = global_output_type + global_output_type = output_type + yield + finally: + global_output_type = previous_output_type + + def determine_array_type(value) -> str: """Utility function to identify the array type.""" if isinstance(value, CumlArray): @@ -212,7 +222,7 @@ def predict(self, X): return CumlArray(y) -def test(): +def example_workflow(): # Example usage from sklearn.datasets import make_regression @@ -228,6 +238,11 @@ def test(): print("Predictions:", predictions, type(predictions)) print("coef", model.coef_, type(model.coef_)) + with override_output_type("cupy"): + assert isinstance(model.coef_, cp.ndarray) + + assert isinstance(model.coef_, np.ndarray) + if __name__ == "__main__": - test() + example_workflow() From 6b9895b26cc2a64ad9c0d7f76bd5c3f5f1507675 Mon Sep 17 00:00:00 2001 From: Simon Adorf Date: Thu, 16 Jan 2025 18:08:42 +0000 Subject: [PATCH 04/22] fixup! Implement prototype for revised type reflection system. --- revised_type_reflection_system_prototype.py | 23 ++++++++++++++------- 1 file changed, 16 insertions(+), 7 deletions(-) diff --git a/revised_type_reflection_system_prototype.py b/revised_type_reflection_system_prototype.py index ad7c2f39b6..9aa961943b 100644 --- a/revised_type_reflection_system_prototype.py +++ b/revised_type_reflection_system_prototype.py @@ -1,16 +1,18 @@ # Copyright (c) 2025, NVIDIA CORPORATION. import inspect -import numpy as np -import cupy as cp -from functools import wraps -from typing import Any from collections.abc import Sequence from contextlib import contextmanager +from functools import wraps +from typing import Any + +import cupy as cp +import numpy as np ## Everything related to API boundary determination -global_api_counter : int = 0 +global_api_counter: int = 0 + def is_api_internal(): return global_api_counter > 1 @@ -32,6 +34,7 @@ def inner(*args, **kwargs): ## CumlArray + class CumlArray: def __init__(self, data): @@ -39,7 +42,7 @@ def __init__(self, data): def to_output(self, output_type: str): match output_type: - case "numpy": + case "numpy": if isinstance(self.data, cp.ndarray): return self.data.get() else: @@ -60,6 +63,7 @@ def as_cuml_array(X) -> CumlArray: ## CumlArrayDescriptor + class CumlArrayDescriptor: def __init__(self, order="K"): @@ -87,6 +91,7 @@ def __get__(self, obj, objtype=None): global_output_type = None + @contextmanager def override_output_type(output_type: str): global global_output_type @@ -109,9 +114,11 @@ def determine_array_type(value) -> str: else: return ValueError(f"Unknown array type: {type(value)}") + def _set_output_type(obj: Any, output_type: str): setattr(obj, "_output_type", output_type) + def _get_output_type(obj: Any) -> str: if global_output_type is None: return getattr(obj, "_output_type", None) @@ -130,7 +137,7 @@ def fit(self, X, y): Sets the output_type of self to the type of the X argument. """ - + def __init__(self, arg_name: str): self.arg_name = arg_name @@ -177,8 +184,10 @@ def inner(obj, *args, **kwargs): return inner + ## Example estimator implementation + class MinimalLinearRegression: coef_ = CumlArrayDescriptor() From fb389c98d3a80ef31ef51af9f41d23006532b43d Mon Sep 17 00:00:00 2001 From: Simon Adorf Date: Thu, 16 Jan 2025 18:43:50 +0000 Subject: [PATCH 05/22] fixup! Implement prototype for revised type reflection system. --- revised_type_reflection_system_prototype.py | 39 +++++++++++++++++++++ 1 file changed, 39 insertions(+) diff --git a/revised_type_reflection_system_prototype.py b/revised_type_reflection_system_prototype.py index 9aa961943b..c74d9f4240 100644 --- a/revised_type_reflection_system_prototype.py +++ b/revised_type_reflection_system_prototype.py @@ -185,6 +185,31 @@ def inner(obj, *args, **kwargs): return inner +class convert_cuml_arrays_to_type_of: + def __init__(self, arg_name: str): + self.arg_name = arg_name + + def __call__(self, func): + sig = inspect.signature(func) + + @wraps(func) + @api_boundary + def inner(*args, **kwargs): + ret = func(*args, **kwargs) + if is_api_internal(): + return ret + else: + bound_args = sig.bind(*args, **kwargs) + bound_args.apply_defaults() + + arg_value = bound_args.arguments.get(self.arg_name) + output_type = determine_array_type(arg_value) + + return _to_output_type(ret, output_type) + + return inner + + ## Example estimator implementation @@ -252,6 +277,20 @@ def example_workflow(): assert isinstance(model.coef_, np.ndarray) + # Example for reflection of types of a stateless function. + + # @convert_cuml_arrays + # @set_output_type("X") + + @convert_cuml_arrays_to_type_of("X") + def power(X, exponent: int): + X = as_cuml_array(X) + result = cp.sqrt(X.to_device_array()) + return as_cuml_array(result) + + squared_X = power(X, 2) + print(type(squared_X)) + if __name__ == "__main__": example_workflow() From 7b2881521ee5344d2f26457d147dbf884fcf56be Mon Sep 17 00:00:00 2001 From: Simon Adorf Date: Thu, 16 Jan 2025 20:30:16 +0000 Subject: [PATCH 06/22] fixup! Implement prototype for revised type reflection system. --- revised_type_reflection_system_prototype.py | 77 +++++++----- type-reflection-docs.md | 132 ++++++++++++++++++++ 2 files changed, 176 insertions(+), 33 deletions(-) create mode 100644 type-reflection-docs.md diff --git a/revised_type_reflection_system_prototype.py b/revised_type_reflection_system_prototype.py index c74d9f4240..5068954b5c 100644 --- a/revised_type_reflection_system_prototype.py +++ b/revised_type_reflection_system_prototype.py @@ -3,6 +3,7 @@ import inspect from collections.abc import Sequence from contextlib import contextmanager +from dataclasses import dataclass from functools import wraps from typing import Any @@ -138,8 +139,11 @@ def fit(self, X, y): Sets the output_type of self to the type of the X argument. """ - def __init__(self, arg_name: str): - self.arg_name = arg_name + def __init__(self, to): + if isinstance(to, str): + to = TypeOfArgument(to) + + self.to = to def __call__(self, func): sig = inspect.signature(func) @@ -150,9 +154,12 @@ def inner(obj, *args, **kwargs): bound_args = sig.bind(obj, *args, **kwargs) bound_args.apply_defaults() - arg_value = bound_args.arguments.get(self.arg_name) - arg_type = determine_array_type(arg_value) - _set_output_type(obj, arg_type) + if isinstance(self.to, TypeOfArgument): + arg_value = bound_args.arguments.get(self.to.argument_name) + arg_type = determine_array_type(arg_value) + _set_output_type(obj, arg_type) + else: + raise TypeError(f"Cannot handle self.to type '{type(self.to)}.") return func(obj, *args, **kwargs) @@ -169,25 +176,19 @@ def _to_output_type(obj, output_type: str): return obj -def convert_cuml_arrays(func): # decorator - """Cuml arrays in method return value are converted.""" +# Sentinels +ObjectOutputType = object() +GlobalOutputType = object() - @wraps(func) - @api_boundary - def inner(obj, *args, **kwargs): - ret = func(obj, *args, **kwargs) - if is_api_internal(): - return ret - else: - output_type = _get_output_type(obj) - return _to_output_type(ret, output_type) - return inner +@dataclass +class TypeOfArgument: + argument_name: str -class convert_cuml_arrays_to_type_of: - def __init__(self, arg_name: str): - self.arg_name = arg_name +class convert_cuml_arrays: + def __init__(self, to=ObjectOutputType): + self.to = to def __call__(self, func): sig = inspect.signature(func) @@ -198,14 +199,25 @@ def inner(*args, **kwargs): ret = func(*args, **kwargs) if is_api_internal(): return ret - else: + elif global_output_type is not None: + return _to_output_type(ret, global_output_type) + elif self.to is ObjectOutputType: + # Use the object's output type. + obj = args[0] + output_type = obj._output_type + elif self.to is GlobalOutputType: + # Always use the global output type. + output_type = global_output_type + elif isinstance(self.to, TypeOfArgument): + # Use the type of the function argument. bound_args = sig.bind(*args, **kwargs) bound_args.apply_defaults() - - arg_value = bound_args.arguments.get(self.arg_name) + arg_value = bound_args.arguments.get(self.to.argument_name) output_type = determine_array_type(arg_value) + else: + raise ValueError(f"Unable to process 'to' argument: {self.to}") - return _to_output_type(ret, output_type) + return _to_output_type(ret, output_type) return inner @@ -218,8 +230,8 @@ class MinimalLinearRegression: coef_ = CumlArrayDescriptor() intercept_ = CumlArrayDescriptor() - # Private methods should not be at the API boundary and should - # not use the @set_output_type decorator. + # Private methods should not be at the API boundary and must + # never use the @set_output_type decorator. def _fit_on_device(self, X: cp.ndarray, y: cp.ndarray): X_design = cp.hstack([cp.ones((X.shape[0], 1)), X]) @@ -245,7 +257,7 @@ def _predict_on_device(self, X: cp.ndarray) -> cp.ndarray: # perform an automatic conversion. return X @ self.coef_.to_device_array() + self.intercept_.to_device_array() - @convert_cuml_arrays + @convert_cuml_arrays(to=ObjectOutputType) def predict(self, X): y = self._predict_on_device(as_cuml_array(X).to_device_array()) @@ -278,18 +290,17 @@ def example_workflow(): assert isinstance(model.coef_, np.ndarray) # Example for reflection of types of a stateless function. - - # @convert_cuml_arrays - # @set_output_type("X") - - @convert_cuml_arrays_to_type_of("X") + @convert_cuml_arrays(to=TypeOfArgument("X")) def power(X, exponent: int): X = as_cuml_array(X) result = cp.sqrt(X.to_device_array()) return as_cuml_array(result) squared_X = power(X, 2) - print(type(squared_X)) + assert isinstance(squared_X, type(X)) + + with override_output_type("cupy"): + assert isinstance(power(X, 2), cp.ndarray) if __name__ == "__main__": diff --git a/type-reflection-docs.md b/type-reflection-docs.md new file mode 100644 index 0000000000..3bd058f7e3 --- /dev/null +++ b/type-reflection-docs.md @@ -0,0 +1,132 @@ +# Revised Type Reflection System + +Author: @csadorf + +## Motivation + +The purpose of the type reflection system is to enable users to provide inputs +and parameters in their preferred format, e.g., numpy or cupy arrays, or pandas +DataFrame objects and have results be returned in the same or format without +needing to worry about any internal processing and conversion of the data. + +For this purpose, estimators will _reflect_ the provided type in most cases. + +## Use cases + +### Data transformation on simple estimator + +Here we assume that a user wants to train a linear regression model on input +data (_X_, _y_) provided in the form of numpy arrays and expects the prediction +result to be returned in the form of a numpy array as well. This is how cuml +would support this: + +```python +# Instantiate and train the estimator +model = MinimalLinearRegression() +model.fit(X, y) +predictions = model.predict(X) +``` + +The type reflection system would guarantee that the following assert holds true: +```python +assert type(predictions) == type(X) +``` + +## Development + +### Estimator Development + +Developing a cuml estimator class that uses the reflection system requires three main components: + +1. The specification of how the desired output type is determined. +2. The specification of which functions should reflect the type and ensuring that to be converted arrays are returned as `CumlArray` type. +3. Specyifying all class attributes that should reflect type are declared as `CumlArrayDescriptor` types. + +Here is a minimal example skipping the actual implementation: + +```python +class MinimalLinearRegression: + + coef_ = CumlArrayDescriptor() + intercept_ = CumlArrayDescriptor() + + @set_output_type("X") + def fit(self, X, y): + ... + return self + + @convert_cuml_arrays() + def predict(self, X): + ... + return CumlArray(y) +``` + +In this case we declared both attributes `coef_` and `intercept_` to be of type +`CumlArrayDescriptor` type which means that they will be automatically converted +to their owner's `output_type` unless the global `output_type` is set. + +The `fit()` method is decorated with the `@set_output_type("X")` decorator which +means that the object's `output_type` should be set to the method's "X" argument +type. + +The `predict()` method is decorated with the `@convert_cuml_arrays()` decorator +which means that `CumlArrays` returned from this function are converted to the +object's `output_type`. + +### Internal vs. external calls + +The type reflection system ensures type consistency for users, however type +conversions should otherwise be avoided to minimize the number of host-device +data transfers. For example, when one cuml estimator calls another cuml +transformer internally, the data should only be copied at the final step when it +is returned to the user. + +To achieve this, we keep track of whether a cuml API call was made externally at +the user-level, or internally. A developer can always check the current API +stack level with the `is_api_internal()` function. + +The `convert_cuml_arrays` decorator will only trigger conversions for external +API calls, right before data is handed back to the user. + +### The global output type + +It is possible to override the dynamic output type by setting the global output type. +Example: + +```python +with override_output_type("cupy"): + ... +``` + +All outputs within this context will be converted to cupy arrays. + +Note: It is **not** possible to opt out of the global output type override. If a +function needs to return a specific type regardless of the global output type +you cannot use the `@convert_cuml_arrays` decorator. This is to ensure +behavioral consistency across the cuml API for both users and developers. + +## Advanced conversion + +The default behavior of the `convert_cuml_arrays()` decorator is to convert cuml arrays + +1. To the global output type if set. +2. The object's output type. + +The behavior can be modified by setting the `to` argument: + +``` +@convert_cuml_arrays() # default behavior +# equivalent to +@convert_cuml_arrays(to=DefaultOutputType) + +If you want to use the + +# Use the type of the argument named "X": +# @convert_cuml_arrays(to=TypeOfArgument("X")) + +# Use the specifically hard-coded type: +# @convert_cuml_arrays(to=SpecificType("cupy")) + +# Always use the globally set output type: +# @convert_cuml_arrays(to=GlobalOutputType) +``` From 6966f3bf01dd75ff7fc4148fd2f9896cf99fbf01 Mon Sep 17 00:00:00 2001 From: Simon Adorf Date: Thu, 16 Jan 2025 21:27:35 +0000 Subject: [PATCH 07/22] fixup! Implement prototype for revised type reflection system. --- revised_type_reflection_system_prototype.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/revised_type_reflection_system_prototype.py b/revised_type_reflection_system_prototype.py index 5068954b5c..76d3311fef 100644 --- a/revised_type_reflection_system_prototype.py +++ b/revised_type_reflection_system_prototype.py @@ -180,6 +180,8 @@ def _to_output_type(obj, output_type: str): ObjectOutputType = object() GlobalOutputType = object() +DefaultOutputType = ObjectOutputType + @dataclass class TypeOfArgument: @@ -187,7 +189,7 @@ class TypeOfArgument: class convert_cuml_arrays: - def __init__(self, to=ObjectOutputType): + def __init__(self, to=DefaultOutputType): self.to = to def __call__(self, func): @@ -197,17 +199,21 @@ def __call__(self, func): @api_boundary def inner(*args, **kwargs): ret = func(*args, **kwargs) + + # Internal call, just return the value without further processing. if is_api_internal(): return ret + + # We use the global output type, whenever it is set. elif global_output_type is not None: return _to_output_type(ret, global_output_type) + + # Use the object's output type, assumes that func is a method with self argument. elif self.to is ObjectOutputType: # Use the object's output type. obj = args[0] output_type = obj._output_type - elif self.to is GlobalOutputType: - # Always use the global output type. - output_type = global_output_type + elif isinstance(self.to, TypeOfArgument): # Use the type of the function argument. bound_args = sig.bind(*args, **kwargs) From 964705d46b937d322ca241812b7bb74fe601e790 Mon Sep 17 00:00:00 2001 From: Simon Adorf Date: Thu, 16 Jan 2025 22:50:06 +0000 Subject: [PATCH 08/22] Implement prototype within cuml namespace. --- python/cuml/cuml/internals/global_settings.py | 13 +- python/cuml/cuml/internals/type_reflection.py | 320 ++++++++++++++++++ 2 files changed, 332 insertions(+), 1 deletion(-) create mode 100644 python/cuml/cuml/internals/type_reflection.py diff --git a/python/cuml/cuml/internals/global_settings.py b/python/cuml/cuml/internals/global_settings.py index 9dae3ceac1..c532bbba47 100644 --- a/python/cuml/cuml/internals/global_settings.py +++ b/python/cuml/cuml/internals/global_settings.py @@ -1,5 +1,5 @@ # -# Copyright (c) 2021-2024, NVIDIA CORPORATION. +# Copyright (c) 2021-2025, NVIDIA CORPORATION. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -53,6 +53,7 @@ def __init__(self): "accelerator_active": False, "accelerator_loaded": False, "accelerated_modules": {}, + "_api_depth": 0, } ) @@ -134,3 +135,13 @@ def output_type(self, value): @property def xpy(self): return self.memory_type.xpy + + @property + def api_depth(self): + return self._api_depth + + def increment_api_depth(self): + self._api_depth += 1 + + def decrement_api_depth(self): + self._api_depth -= 1 diff --git a/python/cuml/cuml/internals/type_reflection.py b/python/cuml/cuml/internals/type_reflection.py new file mode 100644 index 0000000000..652ba078b5 --- /dev/null +++ b/python/cuml/cuml/internals/type_reflection.py @@ -0,0 +1,320 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. + +from cuml.internals.global_settings import GlobalSettings + +import inspect +from collections.abc import Sequence +from contextlib import contextmanager +from dataclasses import dataclass +from functools import wraps +from typing import Any + +import cupy as cp +import numpy as np + +# Everything related to API boundary determination + + +def api_depth_greater_than_zero() -> bool: + return GlobalSettings().api_depth > 0 + + +def api_depth_greater_than_one() -> bool: + return GlobalSettings().api_depth > 1 + + +def cuml_api(func): + @wraps(func) + def inner(*args, **kwargs): + GlobalSettings().increment_api_depth() + try: + return func(*args, **kwargs) + finally: + GlobalSettings().decrement_api_depth() + + return inner + + +# CumlArray + + +class CumlArray: + def __init__(self, data): + self.data = data + + def to_output(self, output_type: str): + match output_type: + case "numpy": + if isinstance(self.data, cp.ndarray): + return self.data.get() + else: + return np.asarray(self.data) + case "cupy": + return cp.asarray(self.data) + case _: + raise TypeError(f"Unknown output_type '{output_type}'.") + + def to_device_array(self) -> cp.ndarray: + return self.to_output("cupy") + + +def as_cuml_array(X) -> CumlArray: + """Wraps array X in CumlArray container.""" + return CumlArray(X) + + +# CumlArrayDescriptor + + +class CumlArrayDescriptor: + def __init__(self, order="K"): + self.order = order + + def __set_name__(self, owner, name): + self.name = name + + def __set__(self, obj, value): + # Just save the provided value as CumlArray + setattr(obj, f"_{self.name}_value", as_cuml_array(value)) + + def __get__(self, obj, objtype=None): + # Return either the original value for internal access or convert to the + # desired output type. + value = getattr(obj, f"_{self.name}_value") + if api_depth_greater_than_zero(): + return value + else: + output_type = _get_output_type(obj) + return value.to_output(output_type) + + +# Type reflection + +global_output_type = None + + +@contextmanager +def override_output_type(output_type: str): + global global_output_type + try: + previous_output_type = global_output_type + global_output_type = output_type + yield + finally: + global_output_type = previous_output_type + + +def determine_array_type(value) -> str: + """Utility function to identify the array type.""" + if isinstance(value, CumlArray): + return "cuml" + elif isinstance(value, np.ndarray): + return "numpy" + elif isinstance(value, cp.ndarray): + return "cupy" + else: + return ValueError(f"Unknown array type: {type(value)}") + + +def _set_output_type(obj: Any, output_type: str): + setattr(obj, "_output_type", output_type) + + +def _get_output_type(obj: Any) -> str: + if global_output_type is None: + return getattr(obj, "_output_type", None) + else: + return global_output_type + + +class set_output_type: + """Set a object's output_type based on a function argument type. + + Example: + + @set_output_type("X") + def fit(self, X, y): + ... + + Sets the output_type of self to the type of the X argument. + """ + + def __init__(self, to): + if isinstance(to, str): + to = TypeOfArgument(to) + + self.to = to + + def __call__(self, func): + sig = inspect.signature(func) + + @cuml_api + def inner(obj, *args, **kwargs): + if not api_depth_greater_than_one(): + bound_args = sig.bind(obj, *args, **kwargs) + bound_args.apply_defaults() + + if isinstance(self.to, TypeOfArgument): + arg_value = bound_args.arguments.get(self.to.argument_name) + arg_type = determine_array_type(arg_value) + _set_output_type(obj, arg_type) + else: + raise TypeError( + f"Cannot handle self.to type '{type(self.to)}." + ) + + return func(obj, *args, **kwargs) + + return inner + + +def _to_output_type(obj, output_type: str): + """Convert CumlArray and containers of CumlArray.""" + if isinstance(obj, CumlArray): + return obj.to_output(output_type) + elif isinstance(obj, Sequence) and not isinstance(obj, str): + return type(obj)(_to_output_type(item) for item in obj) + else: + return obj + + +# Sentinels +ObjectOutputType = object() +GlobalOutputType = object() + +DefaultOutputType = ObjectOutputType + + +@dataclass +class TypeOfArgument: + argument_name: str + + +class convert_cuml_arrays: + def __init__(self, to=DefaultOutputType): + self.to = to + + def __call__(self, func): + sig = inspect.signature(func) + + @wraps(func) + @cuml_api + def inner(*args, **kwargs): + ret = func(*args, **kwargs) + + # Internal call, just return the value without further processing. + if api_depth_greater_than_one(): + return ret + + # We use the global output type, whenever it is set. + elif global_output_type is not None: + return _to_output_type(ret, global_output_type) + + # Use the object's output type, assumes that func is a method with self argument. + elif self.to is ObjectOutputType: + # Use the object's output type. + obj = args[0] + output_type = obj._output_type + + elif isinstance(self.to, TypeOfArgument): + # Use the type of the function argument. + bound_args = sig.bind(*args, **kwargs) + bound_args.apply_defaults() + arg_value = bound_args.arguments.get(self.to.argument_name) + output_type = determine_array_type(arg_value) + else: + raise ValueError(f"Unable to process 'to' argument: {self.to}") + + return _to_output_type(ret, output_type) + + return inner + + +# Example estimator implementation + + +class MinimalLinearRegression: + + coef_ = CumlArrayDescriptor() + intercept_ = CumlArrayDescriptor() + + # Private methods should not be at the API boundary and must + # never use the @set_output_type decorator. + + def _fit_on_device(self, X: cp.ndarray, y: cp.ndarray): + X_design = cp.hstack([cp.ones((X.shape[0], 1)), X]) + + # Compute coefficients using normal equation + weights = cp.linalg.pinv(X_design.T @ X_design) @ X_design.T @ y + + # Separate intercept and coefficients + self.intercept_ = weights[0] + self.coef_ = weights[1:] + + @set_output_type("X") + def fit(self, X, y): + # The implementation here is device specific. We delay the conversion to + # CumlArray and then device array to the latest possible moment. + X, y = as_cuml_array(X), as_cuml_array(y) + self._fit_on_device(X.to_device_array(), y.to_device_array()) + + return self + + def _predict_on_device(self, X: cp.ndarray) -> cp.ndarray: + # This is an API internal method, the array descriptor will not(!) + # perform an automatic conversion. + return ( + X @ self.coef_.to_device_array() + + self.intercept_.to_device_array() + ) + + @convert_cuml_arrays(to=ObjectOutputType) + def predict(self, X): + y = self._predict_on_device(as_cuml_array(X).to_device_array()) + + # By returning the result within the CumlArray container in a function + # at the API boundary decorated with @convert_cuml_arrays, we ensure + # that the return value is automatically converted to reflect the desired + # type. + return CumlArray(y) + + +def example_workflow(): + # Example usage + from sklearn.datasets import make_regression + + # Create synthetic data + X, y = make_regression( + n_samples=20, n_features=2, noise=0.1, random_state=42 + ) + + # Instantiate and train the estimator + model = MinimalLinearRegression() + model.fit(X, y) + + # Make predictions + predictions = model.predict(X) + print("Predictions:", predictions, type(predictions)) + print("coef", model.coef_, type(model.coef_)) + + with override_output_type("cupy"): + assert isinstance(model.coef_, cp.ndarray) + + assert isinstance(model.coef_, np.ndarray) + + # Example for reflection of types of a stateless function. + @convert_cuml_arrays(to=TypeOfArgument("X")) + def power(X, exponent: int): + X = as_cuml_array(X) + result = cp.sqrt(X.to_device_array()) + return as_cuml_array(result) + + squared_X = power(X, 2) + assert isinstance(squared_X, type(X)) + + with override_output_type("cupy"): + assert isinstance(power(X, 2), cp.ndarray) + + +if __name__ == "__main__": + example_workflow() From 0e5c55a201e9f9e8676a4ce0817a90adb1d1b3bd Mon Sep 17 00:00:00 2001 From: Simon Adorf Date: Thu, 16 Jan 2025 22:56:26 +0000 Subject: [PATCH 09/22] fixup! Implement prototype within cuml namespace. --- python/cuml/cuml/internals/type_reflection.py | 24 +++++++------------ 1 file changed, 9 insertions(+), 15 deletions(-) diff --git a/python/cuml/cuml/internals/type_reflection.py b/python/cuml/cuml/internals/type_reflection.py index 652ba078b5..ab485f9dee 100644 --- a/python/cuml/cuml/internals/type_reflection.py +++ b/python/cuml/cuml/internals/type_reflection.py @@ -83,25 +83,24 @@ def __get__(self, obj, objtype=None): value = getattr(obj, f"_{self.name}_value") if api_depth_greater_than_zero(): return value + elif (global_output_type := GlobalSettings().output_type) is not None: + return value.to_output(global_output_type) else: - output_type = _get_output_type(obj) + output_type = obj._output_type return value.to_output(output_type) # Type reflection -global_output_type = None - @contextmanager def override_output_type(output_type: str): - global global_output_type try: - previous_output_type = global_output_type - global_output_type = output_type + previous_output_type = GlobalSettings().output_type + GlobalSettings().output_type = output_type yield finally: - global_output_type = previous_output_type + GlobalSettings().output_type = previous_output_type def determine_array_type(value) -> str: @@ -120,13 +119,6 @@ def _set_output_type(obj: Any, output_type: str): setattr(obj, "_output_type", output_type) -def _get_output_type(obj: Any) -> str: - if global_output_type is None: - return getattr(obj, "_output_type", None) - else: - return global_output_type - - class set_output_type: """Set a object's output_type based on a function argument type. @@ -207,7 +199,9 @@ def inner(*args, **kwargs): return ret # We use the global output type, whenever it is set. - elif global_output_type is not None: + elif ( + global_output_type := GlobalSettings().output_type + ) is not None: return _to_output_type(ret, global_output_type) # Use the object's output type, assumes that func is a method with self argument. From 9b4736e2924f42d2f59ec6d4da8160e7d48f86ec Mon Sep 17 00:00:00 2001 From: Simon Adorf Date: Fri, 17 Jan 2025 14:40:55 +0000 Subject: [PATCH 10/22] Rename 'cuml_api' -> 'cuml_public_api'. --- python/cuml/cuml/internals/type_reflection.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/python/cuml/cuml/internals/type_reflection.py b/python/cuml/cuml/internals/type_reflection.py index ab485f9dee..1c177abaa9 100644 --- a/python/cuml/cuml/internals/type_reflection.py +++ b/python/cuml/cuml/internals/type_reflection.py @@ -23,7 +23,7 @@ def api_depth_greater_than_one() -> bool: return GlobalSettings().api_depth > 1 -def cuml_api(func): +def cuml_public_api(func): @wraps(func) def inner(*args, **kwargs): GlobalSettings().increment_api_depth() @@ -140,7 +140,7 @@ def __init__(self, to): def __call__(self, func): sig = inspect.signature(func) - @cuml_api + @cuml_public_api def inner(obj, *args, **kwargs): if not api_depth_greater_than_one(): bound_args = sig.bind(obj, *args, **kwargs) @@ -190,7 +190,7 @@ def __call__(self, func): sig = inspect.signature(func) @wraps(func) - @cuml_api + @cuml_public_api def inner(*args, **kwargs): ret = func(*args, **kwargs) From 815c6e3022708faddeb670b233e2614ffc22a383 Mon Sep 17 00:00:00 2001 From: Simon Adorf Date: Fri, 17 Jan 2025 14:42:54 +0000 Subject: [PATCH 11/22] Use cuml's CumlArray. --- python/cuml/cuml/internals/type_reflection.py | 23 ++----------------- 1 file changed, 2 insertions(+), 21 deletions(-) diff --git a/python/cuml/cuml/internals/type_reflection.py b/python/cuml/cuml/internals/type_reflection.py index 1c177abaa9..3feb6f830a 100644 --- a/python/cuml/cuml/internals/type_reflection.py +++ b/python/cuml/cuml/internals/type_reflection.py @@ -1,6 +1,7 @@ # Copyright (c) 2025, NVIDIA CORPORATION. from cuml.internals.global_settings import GlobalSettings +from cuml.internals.array import CumlArray import inspect from collections.abc import Sequence @@ -35,27 +36,7 @@ def inner(*args, **kwargs): return inner -# CumlArray - - -class CumlArray: - def __init__(self, data): - self.data = data - - def to_output(self, output_type: str): - match output_type: - case "numpy": - if isinstance(self.data, cp.ndarray): - return self.data.get() - else: - return np.asarray(self.data) - case "cupy": - return cp.asarray(self.data) - case _: - raise TypeError(f"Unknown output_type '{output_type}'.") - - def to_device_array(self) -> cp.ndarray: - return self.to_output("cupy") +# # CumlArray def as_cuml_array(X) -> CumlArray: From a5b1bc52985daa06bb79e39bd6303be9ffbf8be0 Mon Sep 17 00:00:00 2001 From: Simon Adorf Date: Fri, 17 Jan 2025 15:48:57 +0000 Subject: [PATCH 12/22] Tiny revision to the new CumlArrayDescriptor class implementation. --- python/cuml/cuml/internals/type_reflection.py | 35 ++++++++++++++----- 1 file changed, 26 insertions(+), 9 deletions(-) diff --git a/python/cuml/cuml/internals/type_reflection.py b/python/cuml/cuml/internals/type_reflection.py index 3feb6f830a..27a6614549 100644 --- a/python/cuml/cuml/internals/type_reflection.py +++ b/python/cuml/cuml/internals/type_reflection.py @@ -56,19 +56,36 @@ def __set_name__(self, owner, name): def __set__(self, obj, value): # Just save the provided value as CumlArray - setattr(obj, f"_{self.name}_value", as_cuml_array(value)) + setattr(obj, f"_{self.name}_data", as_cuml_array(value)) - def __get__(self, obj, objtype=None): - # Return either the original value for internal access or convert to the - # desired output type. - value = getattr(obj, f"_{self.name}_value") + def __get__(self, obj, _=None): + + if ( + obj is None + ): # descriptor was accessed on class rather than instance + return self + + # Get data from the owning object + array = getattr(obj, f"_{self.name}_data") + + # This is accessed internally, just return the cuml array directly. if api_depth_greater_than_zero(): - return value + return array + + # The global output type is set, return the array converted to that. elif (global_output_type := GlobalSettings().output_type) is not None: - return value.to_output(global_output_type) + return array.to_output(global_output_type) + + # Return the array converted to the object's _output_type + elif (output_type := obj._output_type) is not None: + return array.to_output(output_type) + + # Neither the global nor the object's output_type are set. Since this + # is a user call, we must fail. else: - output_type = obj._output_type - return value.to_output(output_type) + raise RuntimeError( + "Tried to access CumlArrayDescriptor without output_type set." + ) # Type reflection From 2456eb9421119ae2eb8ac65d36366e4e8b719f6f Mon Sep 17 00:00:00 2001 From: Simon Adorf Date: Fri, 17 Jan 2025 16:16:33 +0000 Subject: [PATCH 13/22] Revise internal_api related names. --- python/cuml/cuml/internals/type_reflection.py | 13 +++++-------- 1 file changed, 5 insertions(+), 8 deletions(-) diff --git a/python/cuml/cuml/internals/type_reflection.py b/python/cuml/cuml/internals/type_reflection.py index 27a6614549..9d9138667e 100644 --- a/python/cuml/cuml/internals/type_reflection.py +++ b/python/cuml/cuml/internals/type_reflection.py @@ -16,11 +16,7 @@ # Everything related to API boundary determination -def api_depth_greater_than_zero() -> bool: - return GlobalSettings().api_depth > 0 - - -def api_depth_greater_than_one() -> bool: +def in_internal_api() -> bool: return GlobalSettings().api_depth > 1 @@ -58,6 +54,7 @@ def __set__(self, obj, value): # Just save the provided value as CumlArray setattr(obj, f"_{self.name}_data", as_cuml_array(value)) + @cuml_public_api def __get__(self, obj, _=None): if ( @@ -69,7 +66,7 @@ def __get__(self, obj, _=None): array = getattr(obj, f"_{self.name}_data") # This is accessed internally, just return the cuml array directly. - if api_depth_greater_than_zero(): + if in_internal_api(): return array # The global output type is set, return the array converted to that. @@ -140,7 +137,7 @@ def __call__(self, func): @cuml_public_api def inner(obj, *args, **kwargs): - if not api_depth_greater_than_one(): + if not in_internal_api(): bound_args = sig.bind(obj, *args, **kwargs) bound_args.apply_defaults() @@ -193,7 +190,7 @@ def inner(*args, **kwargs): ret = func(*args, **kwargs) # Internal call, just return the value without further processing. - if api_depth_greater_than_one(): + if in_internal_api(): return ret # We use the global output type, whenever it is set. From 04230f14b4c6ec4ef52c6172ba6b872acff92be4 Mon Sep 17 00:00:00 2001 From: Simon Adorf Date: Fri, 17 Jan 2025 16:23:16 +0000 Subject: [PATCH 14/22] further refine --- python/cuml/cuml/internals/type_reflection.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/python/cuml/cuml/internals/type_reflection.py b/python/cuml/cuml/internals/type_reflection.py index 9d9138667e..3aaaadeba6 100644 --- a/python/cuml/cuml/internals/type_reflection.py +++ b/python/cuml/cuml/internals/type_reflection.py @@ -57,9 +57,8 @@ def __set__(self, obj, value): @cuml_public_api def __get__(self, obj, _=None): - if ( - obj is None - ): # descriptor was accessed on class rather than instance + # The descriptor was accessed on a class rather than an instance. + if obj is None: return self # Get data from the owning object @@ -84,6 +83,10 @@ def __get__(self, obj, _=None): "Tried to access CumlArrayDescriptor without output_type set." ) + def __delete__(self, obj): + if obj is not None: + delattr(obj, f"_{self.name}_data") + # Type reflection @@ -135,6 +138,7 @@ def __init__(self, to): def __call__(self, func): sig = inspect.signature(func) + @wraps(func) @cuml_public_api def inner(obj, *args, **kwargs): if not in_internal_api(): From a36cea7ca37973149a438ce875455e04b5c6ed72 Mon Sep 17 00:00:00 2001 From: Simon Adorf Date: Fri, 17 Jan 2025 16:46:04 +0000 Subject: [PATCH 15/22] Implement CumlArrayDescriptor caching. --- python/cuml/cuml/internals/type_reflection.py | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/python/cuml/cuml/internals/type_reflection.py b/python/cuml/cuml/internals/type_reflection.py index 3aaaadeba6..e09d4c6cc9 100644 --- a/python/cuml/cuml/internals/type_reflection.py +++ b/python/cuml/cuml/internals/type_reflection.py @@ -51,8 +51,17 @@ def __set_name__(self, owner, name): self.name = name def __set__(self, obj, value): - # Just save the provided value as CumlArray + # Save the provided value as CumlArray and initialize output cache. setattr(obj, f"_{self.name}_data", as_cuml_array(value)) + setattr(obj, f"_{self.name}_output_cache", dict()) + + def _to_cached_output(self, obj, array, output_type): + output_cache = getattr(obj, f"_{self.name}_output_cache") + + if output_type not in output_cache: + output_cache[output_type] = array.to_output(output_type) + + return output_cache[output_type] @cuml_public_api def __get__(self, obj, _=None): @@ -70,11 +79,11 @@ def __get__(self, obj, _=None): # The global output type is set, return the array converted to that. elif (global_output_type := GlobalSettings().output_type) is not None: - return array.to_output(global_output_type) + return self._to_cached_output(obj, array, global_output_type) # Return the array converted to the object's _output_type elif (output_type := obj._output_type) is not None: - return array.to_output(output_type) + return self._to_cached_output(obj, array, output_type) # Neither the global nor the object's output_type are set. Since this # is a user call, we must fail. From bc8d5e2618510d9b40b2fad61f10dee5673d670c Mon Sep 17 00:00:00 2001 From: Simon Adorf Date: Fri, 17 Jan 2025 20:43:37 +0000 Subject: [PATCH 16/22] Implement revision for LinearRegression estimator. --- python/cuml/cuml/common/__init__.py | 11 +- python/cuml/cuml/internals/base.pyx | 141 +----------------- python/cuml/cuml/internals/type_reflection.py | 35 ++++- python/cuml/cuml/linear_model/base.pyx | 9 +- .../cuml/linear_model/linear_regression.pyx | 5 +- 5 files changed, 48 insertions(+), 153 deletions(-) diff --git a/python/cuml/cuml/common/__init__.py b/python/cuml/cuml/common/__init__.py index e267bf668b..6cf98be743 100644 --- a/python/cuml/cuml/common/__init__.py +++ b/python/cuml/cuml/common/__init__.py @@ -1,5 +1,5 @@ # -# Copyright (c) 2019-2023, NVIDIA CORPORATION. +# Copyright (c) 2019-2025, NVIDIA CORPORATION. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -40,6 +40,12 @@ from cuml.internals.memory_utils import with_cupy_rmm from cuml.common.device_selection import using_device_type +from cuml.internals.type_reflection import ( + CumlArrayDescriptor, + set_output_type, + convert_cuml_arrays, +) + if is_cuda_available(): from cuml.common.pointer_utils import device_of_gpu_matrix @@ -51,6 +57,7 @@ __all__ = [ "CumlArray", + "CumlArrayDescriptor", "SparseCumlArray", "device_of_gpu_matrix", "has_cupy", @@ -67,6 +74,8 @@ "using_memory_type", "using_output_type", "with_cupy_rmm", + "convert_cuml_arrays", + "set_output_type", "sparse_scipy_to_cp", "timed", ] diff --git a/python/cuml/cuml/internals/base.pyx b/python/cuml/cuml/internals/base.pyx index 32b2cad908..9fce926694 100644 --- a/python/cuml/cuml/internals/base.pyx +++ b/python/cuml/cuml/internals/base.pyx @@ -109,8 +109,7 @@ class VerbosityDescriptor: obj._verbose = value -class Base(TagsMixin, - metaclass=cuml.internals.BaseMetaClass): +class Base(TagsMixin): """ Base class for all the ML algos. It handles some of the common operations across all algos. Every ML algo class exposed at cython level must inherit @@ -283,15 +282,6 @@ class Base(TagsMixin, # rendered unnecessary with https://github.com/rapidsai/cuml/pull/6189. GlobalSettings().root_cm = GlobalSettings().prev_root_cm - self.output_type = _check_output_type_str( - cuml.global_settings.output_type - if output_type is None else output_type) - if output_mem_type is None: - self.output_mem_type = cuml.global_settings.memory_type - else: - self.output_mem_type = MemoryType.from_str(output_mem_type) - self._input_type = None - self._input_mem_type = None self.target_dtype = None self.n_features_in_ = None @@ -397,115 +387,6 @@ class Base(TagsMixin, else: raise AttributeError(attr) - def _set_base_attributes(self, - output_type=None, - target_dtype=None, - n_features=None): - """ - Method to set the base class attributes - output type, - target dtype and n_features. It combines the three different - function calls. It's called in fit function from estimators. - - Parameters - -------- - output_type : DataFrame (default = None) - Is output_type is passed, aets the output_type on the - dataframe passed - target_dtype : Target column (default = None) - If target_dtype is passed, we call _set_target_dtype - on it - n_features: int or DataFrame (default=None) - If an int is passed, we set it to the number passed - If dataframe, we set it based on the passed df. - - Examples - -------- - - .. code-block:: python - - # To set output_type and n_features based on X - self._set_base_attributes(output_type=X, n_features=X) - - # To set output_type on X and n_features to 10 - self._set_base_attributes(output_type=X, n_features=10) - - # To only set target_dtype - self._set_base_attributes(output_type=X, target_dtype=y) - """ - if output_type is not None: - self._set_output_type(output_type) - self._set_output_mem_type(output_type) - if target_dtype is not None: - self._set_target_dtype(target_dtype) - if n_features is not None: - self._set_n_features_in(n_features) - - def _set_output_type(self, inp): - self._input_type = determine_array_type(inp) - - def _set_output_mem_type(self, inp): - self._input_mem_type = determine_array_memtype( - inp - ) - - def _get_output_type(self, inp): - """ - Method to be called by predict/transform methods of inheriting classes. - Returns the appropriate output type depending on the type of the input, - class output type and global output type. - """ - - # Default to the global type - output_type = cuml.global_settings.output_type - - # If its None, default to our type - if (output_type is None or output_type == "mirror"): - output_type = self.output_type - - # If we are input, get the type from the input - if output_type == 'input': - output_type = determine_array_type(inp) - - return output_type - - def _get_output_mem_type(self, inp): - """ - Method to be called by predict/transform methods of inheriting classes. - Returns the appropriate memory type depending on the type of the input, - class output type and global output type. - """ - - # Default to the global type - mem_type = cuml.global_settings.memory_type - - # If we are input, get the type from the input - if cuml.global_settings.output_type == 'input': - mem_type = determine_array_memtype(inp) - - return mem_type - - def _set_target_dtype(self, target): - self.target_dtype = cuml.internals.input_utils.determine_array_dtype( - target) - - def _get_target_dtype(self): - """ - Method to be called by predict/transform methods of - inheriting classifier classes. Returns the appropriate output - dtype depending on the dtype of the target. - """ - try: - out_dtype = self.target_dtype - except AttributeError: - out_dtype = None - return out_dtype - - def _set_n_features_in(self, X): - if isinstance(X, int): - self.n_features_in_ = X - else: - self.n_features_in_ = X.shape[1] - def _more_tags(self): # 'preserves_dtype' tag's Scikit definition currently only applies to # transformers and whether the transform method conserves the dtype @@ -593,26 +474,6 @@ def _check_output_type_str(output_str): ) -def _determine_stateless_output_type(output_type, input_obj): - """ - This function determines the output type using the same steps that are - performed in `cuml.common.base.Base`. This can be used to mimic the - functionality in `Base` for stateless functions or objects that do not - derive from `Base`. - """ - - # Default to the global type if not specified, otherwise, check the - # output_type string - temp_output = cuml.global_settings.output_type if output_type is None \ - else _check_output_type_str(output_type) - - # If we are using 'input', determine the the type from the input object - if temp_output == 'input': - temp_output = determine_array_type(input_obj) - - return temp_output - - class UniversalBase(Base): # variable to enable dispatching non-implemented methods to CPU # estimators, experimental. diff --git a/python/cuml/cuml/internals/type_reflection.py b/python/cuml/cuml/internals/type_reflection.py index e09d4c6cc9..6e3a0f16e5 100644 --- a/python/cuml/cuml/internals/type_reflection.py +++ b/python/cuml/cuml/internals/type_reflection.py @@ -2,6 +2,7 @@ from cuml.internals.global_settings import GlobalSettings from cuml.internals.array import CumlArray +from cuml.internals.input_utils import input_to_cuml_array import inspect from collections.abc import Sequence @@ -32,15 +33,17 @@ def inner(*args, **kwargs): return inner -# # CumlArray +# CumlArray -def as_cuml_array(X) -> CumlArray: +def as_cuml_array(X, dtype=None) -> CumlArray: """Wraps array X in CumlArray container.""" - return CumlArray(X) + # TODO: After debugging, replace this with an immediate call to the CumlArray constructor. + return CumlArray(X, dtype=dtype) # CumlArrayDescriptor +# TODO: Replace CumlArrayDescriptor in cuml/common/array_descriptor.py class CumlArrayDescriptor: @@ -52,7 +55,12 @@ def __set_name__(self, owner, name): def __set__(self, obj, value): # Save the provided value as CumlArray and initialize output cache. - setattr(obj, f"_{self.name}_data", as_cuml_array(value)) + dtype = _get_dtype(obj) + setattr( + obj, + f"_{self.name}_data", + None if value is None else as_cuml_array(value, dtype), + ) setattr(obj, f"_{self.name}_output_cache", dict()) def _to_cached_output(self, obj, array, output_type): @@ -82,7 +90,7 @@ def __get__(self, obj, _=None): return self._to_cached_output(obj, array, global_output_type) # Return the array converted to the object's _output_type - elif (output_type := obj._output_type) is not None: + elif (output_type := getattr(obj, "_output_type", None)) is not None: return self._to_cached_output(obj, array, output_type) # Neither the global nor the object's output_type are set. Since this @@ -122,10 +130,22 @@ def determine_array_type(value) -> str: return ValueError(f"Unknown array type: {type(value)}") +def determine_array_dtype(value): + return value.dtype + + def _set_output_type(obj: Any, output_type: str): setattr(obj, "_output_type", output_type) +def _set_dtype(obj: Any, dtype): + setattr(obj, "dtype", dtype) + + +def _get_dtype(obj: Any): + return getattr(obj, "dtype", None) + + class set_output_type: """Set a object's output_type based on a function argument type. @@ -138,11 +158,12 @@ def fit(self, X, y): Sets the output_type of self to the type of the X argument. """ - def __init__(self, to): + def __init__(self, to, dtype=None): if isinstance(to, str): to = TypeOfArgument(to) self.to = to + self.dtype = dtype def __call__(self, func): sig = inspect.signature(func) @@ -157,7 +178,9 @@ def inner(obj, *args, **kwargs): if isinstance(self.to, TypeOfArgument): arg_value = bound_args.arguments.get(self.to.argument_name) arg_type = determine_array_type(arg_value) + dtype = self.dtype or determine_array_dtype(arg_value) _set_output_type(obj, arg_type) + _set_dtype(obj, dtype) else: raise TypeError( f"Cannot handle self.to type '{type(self.to)}." diff --git a/python/cuml/cuml/linear_model/base.pyx b/python/cuml/cuml/linear_model/base.pyx index 718051658a..cc3de5c174 100644 --- a/python/cuml/cuml/linear_model/base.pyx +++ b/python/cuml/cuml/linear_model/base.pyx @@ -1,5 +1,5 @@ # -# Copyright (c) 2020-2023, NVIDIA CORPORATION. +# Copyright (c) 2020-2025, NVIDIA CORPORATION. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -31,6 +31,7 @@ from cuml.internals.array import CumlArray from cuml.internals.input_utils import input_to_cuml_array from cuml.common.doc_utils import generate_docstring from cuml.internals.api_decorators import enable_device_interop +from cuml.common import convert_cuml_arrays IF GPUBUILD == 1: @@ -59,8 +60,8 @@ class LinearPredictMixin: 'type': 'dense', 'description': 'Predicted values', 'shape': '(n_samples, 1)'}) - @cuml.internals.api_base_return_array_skipall @enable_device_interop + @convert_cuml_arrays() def predict(self, X, convert_dtype=True) -> CumlArray: """ Predicts `y` values for `X`. @@ -109,7 +110,7 @@ class LinearPredictMixin: _n_rows, _n_cols, _coef_ptr, - self.intercept_, + self.intercept_.to_device_array(), _preds_ptr) else: gemmPredict(handle_[0], @@ -117,7 +118,7 @@ class LinearPredictMixin: _n_rows, _n_cols, _coef_ptr, - self.intercept_, + self.intercept_.to_device_array(), _preds_ptr) self.handle.sync() diff --git a/python/cuml/cuml/linear_model/linear_regression.pyx b/python/cuml/cuml/linear_model/linear_regression.pyx index 4da9ee7103..6104e6ebce 100644 --- a/python/cuml/cuml/linear_model/linear_regression.pyx +++ b/python/cuml/cuml/linear_model/linear_regression.pyx @@ -1,5 +1,5 @@ # -# Copyright (c) 2019-2024, NVIDIA CORPORATION. +# Copyright (c) 2019-2025, NVIDIA CORPORATION. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -28,7 +28,7 @@ cuda = gpu_only_import_from('numba', 'cuda') from libc.stdint cimport uintptr_t from cuml.internals.array import CumlArray -from cuml.common.array_descriptor import CumlArrayDescriptor +from cuml.common import CumlArrayDescriptor, set_output_type from cuml.internals.base import UniversalBase from cuml.internals.mixins import RegressorMixin, FMajorInputTagMixin from cuml.common.doc_utils import generate_docstring @@ -314,6 +314,7 @@ class LinearRegression(LinearPredictMixin, @generate_docstring() @enable_device_interop + @set_output_type("X") def fit(self, X, y, convert_dtype=True, sample_weight=None) -> "LinearRegression": """ From 3a3aacd95399e421bf6de3bcbb4fd3092382e405 Mon Sep 17 00:00:00 2001 From: Simon Adorf Date: Fri, 17 Jan 2025 20:59:00 +0000 Subject: [PATCH 17/22] Update docs and add example workflow module. --- test_linear_model.py | 22 ++++++++++++++++++++++ type-reflection-docs.md | 12 ++++++------ 2 files changed, 28 insertions(+), 6 deletions(-) create mode 100644 test_linear_model.py diff --git a/test_linear_model.py b/test_linear_model.py new file mode 100644 index 0000000000..9656405b90 --- /dev/null +++ b/test_linear_model.py @@ -0,0 +1,22 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. + +from cuml import LinearRegression + +from sklearn.datasets import make_regression + +# Create synthetic data +X, y = make_regression(n_samples=20, n_features=2, noise=0.1, random_state=42) + +# Instantiate and train the estimator +model = LinearRegression() +model.fit(X, y) + +# Make predictions +predictions = model.predict(X) +print("Predictions:", predictions, type(predictions)) +print("coef", model.coef_, type(model.coef_)) + +# with override_output_type("cupy"): + # assert isinstance(model.coef_, cp.ndarray) +# +# assert isinstance(model.coef_, np.ndarray) diff --git a/type-reflection-docs.md b/type-reflection-docs.md index 3bd058f7e3..927f19ceb7 100644 --- a/type-reflection-docs.md +++ b/type-reflection-docs.md @@ -83,7 +83,7 @@ is returned to the user. To achieve this, we keep track of whether a cuml API call was made externally at the user-level, or internally. A developer can always check the current API -stack level with the `is_api_internal()` function. +stack level with the `in_internal_api()` function. The `convert_cuml_arrays` decorator will only trigger conversions for external API calls, right before data is handed back to the user. @@ -98,7 +98,8 @@ with override_output_type("cupy"): ... ``` -All outputs within this context will be converted to cupy arrays. +All outputs within this context will be converted to cupy arrays regardless of +any other configuration. Note: It is **not** possible to opt out of the global output type override. If a function needs to return a specific type regardless of the global output type @@ -112,6 +113,8 @@ The default behavior of the `convert_cuml_arrays()` decorator is to convert cuml 1. To the global output type if set. 2. The object's output type. +_The function will fail if neither is set._ + The behavior can be modified by setting the `to` argument: ``` @@ -122,10 +125,7 @@ The behavior can be modified by setting the `to` argument: If you want to use the # Use the type of the argument named "X": -# @convert_cuml_arrays(to=TypeOfArgument("X")) - -# Use the specifically hard-coded type: -# @convert_cuml_arrays(to=SpecificType("cupy")) +# @convert_cuml_arrays(to=(DefaultOutputType, TypeOfArgument("X"))) # Always use the globally set output type: # @convert_cuml_arrays(to=GlobalOutputType) From 3d7decd680dafa97114c8f4af8fc1fbb97d30143 Mon Sep 17 00:00:00 2001 From: Simon Adorf Date: Fri, 17 Jan 2025 21:02:10 +0000 Subject: [PATCH 18/22] fixup! Update docs and add example workflow module. --- test_linear_model.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/test_linear_model.py b/test_linear_model.py index 9656405b90..5fe9daad3b 100644 --- a/test_linear_model.py +++ b/test_linear_model.py @@ -1,7 +1,10 @@ # Copyright (c) 2025, NVIDIA CORPORATION. from cuml import LinearRegression +from cuml import using_output_type +import numpy as np +import cupy as cp from sklearn.datasets import make_regression # Create synthetic data @@ -16,7 +19,7 @@ print("Predictions:", predictions, type(predictions)) print("coef", model.coef_, type(model.coef_)) -# with override_output_type("cupy"): - # assert isinstance(model.coef_, cp.ndarray) -# -# assert isinstance(model.coef_, np.ndarray) +with using_output_type("cupy"): + assert isinstance(model.coef_, cp.ndarray) + +assert isinstance(model.coef_, np.ndarray) From 93cd87db684d01fa97c14613c61349c2c12342c8 Mon Sep 17 00:00:00 2001 From: Simon Adorf Date: Fri, 17 Jan 2025 21:12:45 +0000 Subject: [PATCH 19/22] Use pre-existing CM for overriding global output type. --- python/cuml/cuml/internals/type_reflection.py | 17 +++-------------- revised_type_reflection_system_prototype.py | 6 +++--- type-reflection-docs.md | 2 +- 3 files changed, 7 insertions(+), 18 deletions(-) diff --git a/python/cuml/cuml/internals/type_reflection.py b/python/cuml/cuml/internals/type_reflection.py index 6e3a0f16e5..9a879d0e8b 100644 --- a/python/cuml/cuml/internals/type_reflection.py +++ b/python/cuml/cuml/internals/type_reflection.py @@ -2,7 +2,7 @@ from cuml.internals.global_settings import GlobalSettings from cuml.internals.array import CumlArray -from cuml.internals.input_utils import input_to_cuml_array +from cuml.internals.memory_utils import using_output_type import inspect from collections.abc import Sequence @@ -108,16 +108,6 @@ def __delete__(self, obj): # Type reflection -@contextmanager -def override_output_type(output_type: str): - try: - previous_output_type = GlobalSettings().output_type - GlobalSettings().output_type = output_type - yield - finally: - GlobalSettings().output_type = previous_output_type - - def determine_array_type(value) -> str: """Utility function to identify the array type.""" if isinstance(value, CumlArray): @@ -203,7 +193,6 @@ def _to_output_type(obj, output_type: str): # Sentinels ObjectOutputType = object() -GlobalOutputType = object() DefaultOutputType = ObjectOutputType @@ -322,7 +311,7 @@ def example_workflow(): print("Predictions:", predictions, type(predictions)) print("coef", model.coef_, type(model.coef_)) - with override_output_type("cupy"): + with using_output_type("cupy"): assert isinstance(model.coef_, cp.ndarray) assert isinstance(model.coef_, np.ndarray) @@ -337,7 +326,7 @@ def power(X, exponent: int): squared_X = power(X, 2) assert isinstance(squared_X, type(X)) - with override_output_type("cupy"): + with using_output_type("cupy"): assert isinstance(power(X, 2), cp.ndarray) diff --git a/revised_type_reflection_system_prototype.py b/revised_type_reflection_system_prototype.py index 76d3311fef..6d9dd82993 100644 --- a/revised_type_reflection_system_prototype.py +++ b/revised_type_reflection_system_prototype.py @@ -94,7 +94,7 @@ def __get__(self, obj, objtype=None): @contextmanager -def override_output_type(output_type: str): +def using_output_type(output_type: str): global global_output_type try: previous_output_type = global_output_type @@ -290,7 +290,7 @@ def example_workflow(): print("Predictions:", predictions, type(predictions)) print("coef", model.coef_, type(model.coef_)) - with override_output_type("cupy"): + with using_output_type("cupy"): assert isinstance(model.coef_, cp.ndarray) assert isinstance(model.coef_, np.ndarray) @@ -305,7 +305,7 @@ def power(X, exponent: int): squared_X = power(X, 2) assert isinstance(squared_X, type(X)) - with override_output_type("cupy"): + with using_output_type("cupy"): assert isinstance(power(X, 2), cp.ndarray) diff --git a/type-reflection-docs.md b/type-reflection-docs.md index 927f19ceb7..764bdc7747 100644 --- a/type-reflection-docs.md +++ b/type-reflection-docs.md @@ -94,7 +94,7 @@ It is possible to override the dynamic output type by setting the global output Example: ```python -with override_output_type("cupy"): +with using_output_type("cupy"): ... ``` From d27dc22d0fbd141779b636309d8f48596a6b9405 Mon Sep 17 00:00:00 2001 From: Simon Adorf Date: Tue, 21 Jan 2025 19:53:30 +0000 Subject: [PATCH 20/22] Implement logistic regression. --- python/cuml/cuml/internals/base.pyx | 2 + python/cuml/cuml/internals/mixins.py | 8 +- python/cuml/cuml/internals/type_reflection.py | 90 ------------------- .../cuml/linear_model/logistic_regression.pyx | 15 ++-- python/cuml/cuml/linear_model/ridge.pyx | 5 +- python/cuml/cuml/solvers/qn.pyx | 12 +-- test_linear_model.py | 9 +- 7 files changed, 30 insertions(+), 111 deletions(-) diff --git a/python/cuml/cuml/internals/base.pyx b/python/cuml/cuml/internals/base.pyx index 9fce926694..6788fe7b67 100644 --- a/python/cuml/cuml/internals/base.pyx +++ b/python/cuml/cuml/internals/base.pyx @@ -380,6 +380,8 @@ class Base(TagsMixin): """ Redirects to `solver_model` if the attribute exists. """ + # TODO: I think we should handle this explicitly on a + # estimator-by-estimator basis. if attr == "solver_model": return self.__dict__['solver_model'] if "solver_model" in self.__dict__.keys(): diff --git a/python/cuml/cuml/internals/mixins.py b/python/cuml/cuml/internals/mixins.py index c47dc56754..3da1674948 100644 --- a/python/cuml/cuml/internals/mixins.py +++ b/python/cuml/cuml/internals/mixins.py @@ -1,5 +1,5 @@ # -# Copyright (c) 2021-2023, NVIDIA CORPORATION. +# Copyright (c) 2021-2025, NVIDIA CORPORATION. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -18,9 +18,9 @@ from copy import deepcopy from cuml.common.doc_utils import generate_docstring -from cuml.internals.api_decorators import api_base_return_any_skipall from cuml.internals.base_helpers import _tags_class_and_instance from cuml.internals.api_decorators import enable_device_interop +from cuml.common import convert_cuml_arrays ############################################################################### @@ -198,8 +198,8 @@ class RegressorMixin: "description": "R^2 of self.predict(X) " "wrt. y.", } ) - @api_base_return_any_skipall @enable_device_interop + @convert_cuml_arrays() def score(self, X, y, **kwargs): """ Scoring function for regression estimators @@ -239,8 +239,8 @@ class ClassifierMixin: ), } ) - @api_base_return_any_skipall @enable_device_interop + @convert_cuml_arrays() def score(self, X, y, **kwargs): """ Scoring function for classifier estimators based on mean accuracy. diff --git a/python/cuml/cuml/internals/type_reflection.py b/python/cuml/cuml/internals/type_reflection.py index 9a879d0e8b..47ca60fc91 100644 --- a/python/cuml/cuml/internals/type_reflection.py +++ b/python/cuml/cuml/internals/type_reflection.py @@ -242,93 +242,3 @@ def inner(*args, **kwargs): return _to_output_type(ret, output_type) return inner - - -# Example estimator implementation - - -class MinimalLinearRegression: - - coef_ = CumlArrayDescriptor() - intercept_ = CumlArrayDescriptor() - - # Private methods should not be at the API boundary and must - # never use the @set_output_type decorator. - - def _fit_on_device(self, X: cp.ndarray, y: cp.ndarray): - X_design = cp.hstack([cp.ones((X.shape[0], 1)), X]) - - # Compute coefficients using normal equation - weights = cp.linalg.pinv(X_design.T @ X_design) @ X_design.T @ y - - # Separate intercept and coefficients - self.intercept_ = weights[0] - self.coef_ = weights[1:] - - @set_output_type("X") - def fit(self, X, y): - # The implementation here is device specific. We delay the conversion to - # CumlArray and then device array to the latest possible moment. - X, y = as_cuml_array(X), as_cuml_array(y) - self._fit_on_device(X.to_device_array(), y.to_device_array()) - - return self - - def _predict_on_device(self, X: cp.ndarray) -> cp.ndarray: - # This is an API internal method, the array descriptor will not(!) - # perform an automatic conversion. - return ( - X @ self.coef_.to_device_array() - + self.intercept_.to_device_array() - ) - - @convert_cuml_arrays(to=ObjectOutputType) - def predict(self, X): - y = self._predict_on_device(as_cuml_array(X).to_device_array()) - - # By returning the result within the CumlArray container in a function - # at the API boundary decorated with @convert_cuml_arrays, we ensure - # that the return value is automatically converted to reflect the desired - # type. - return CumlArray(y) - - -def example_workflow(): - # Example usage - from sklearn.datasets import make_regression - - # Create synthetic data - X, y = make_regression( - n_samples=20, n_features=2, noise=0.1, random_state=42 - ) - - # Instantiate and train the estimator - model = MinimalLinearRegression() - model.fit(X, y) - - # Make predictions - predictions = model.predict(X) - print("Predictions:", predictions, type(predictions)) - print("coef", model.coef_, type(model.coef_)) - - with using_output_type("cupy"): - assert isinstance(model.coef_, cp.ndarray) - - assert isinstance(model.coef_, np.ndarray) - - # Example for reflection of types of a stateless function. - @convert_cuml_arrays(to=TypeOfArgument("X")) - def power(X, exponent: int): - X = as_cuml_array(X) - result = cp.sqrt(X.to_device_array()) - return as_cuml_array(result) - - squared_X = power(X, 2) - assert isinstance(squared_X, type(X)) - - with using_output_type("cupy"): - assert isinstance(power(X, 2), cp.ndarray) - - -if __name__ == "__main__": - example_workflow() diff --git a/python/cuml/cuml/linear_model/logistic_regression.pyx b/python/cuml/cuml/linear_model/logistic_regression.pyx index e968093c8e..44f1f51825 100644 --- a/python/cuml/cuml/linear_model/logistic_regression.pyx +++ b/python/cuml/cuml/linear_model/logistic_regression.pyx @@ -26,12 +26,14 @@ import cuml.internals from cuml.solvers import QN from cuml.internals.base import UniversalBase from cuml.internals.mixins import ClassifierMixin, FMajorInputTagMixin, SparseInputTagMixin -from cuml.common.array_descriptor import CumlArrayDescriptor +from cuml.common import CumlArrayDescriptor from cuml.internals.array import CumlArray from cuml.common.doc_utils import generate_docstring from cuml.internals import logger from cuml.common import input_to_cuml_array from cuml.common import using_output_type +from cuml.common import set_output_type +from cuml.common import convert_cuml_arrays from cuml.internals.api_decorators import device_interop_preparation from cuml.internals.api_decorators import enable_device_interop cp = gpu_only_import('cupy') @@ -286,8 +288,8 @@ class LogisticRegression(UniversalBase, self.verb_prefix = "" @generate_docstring(X='dense_sparse') - @cuml.internals.api_base_return_any(set_output_dtype=True) @enable_device_interop + @set_output_type("X") def fit(self, X, y, sample_weight=None, convert_dtype=True) -> "LogisticRegression": """ @@ -390,6 +392,7 @@ class LogisticRegression(UniversalBase, 'description': 'Confidence score', 'shape': '(n_samples, n_classes)'}) @enable_device_interop + @convert_cuml_arrays() def decision_function(self, X, convert_dtype=True) -> CumlArray: """ Gives confidence score for X @@ -405,8 +408,8 @@ class LogisticRegression(UniversalBase, 'type': 'dense', 'description': 'Predicted values', 'shape': '(n_samples, 1)'}) - @cuml.internals.api_base_return_array(get_output_dtype=True) @enable_device_interop + @convert_cuml_arrays() def predict(self, X, convert_dtype=True) -> CumlArray: """ Predicts the y for X. @@ -421,6 +424,7 @@ class LogisticRegression(UniversalBase, probabilities', 'shape': '(n_samples, n_classes)'}) @enable_device_interop + @convert_cuml_arrays() def predict_proba(self, X, convert_dtype=True) -> CumlArray: """ Predicts the class probabilities for each class in X @@ -438,6 +442,7 @@ class LogisticRegression(UniversalBase, class probabilities', 'shape': '(n_samples, n_classes)'}) @enable_device_interop + @convert_cuml_arrays() def predict_log_proba(self, X, convert_dtype=True) -> CumlArray: """ Predicts the log class probabilities for each class in X @@ -529,7 +534,7 @@ class LogisticRegression(UniversalBase, return self @property - @cuml.internals.api_base_return_array_skipall + @convert_cuml_arrays() def coef_(self): return self.solver_model.coef_ @@ -538,7 +543,7 @@ class LogisticRegression(UniversalBase, self.solver_model.coef_ = value @property - @cuml.internals.api_base_return_array_skipall + @convert_cuml_arrays() def intercept_(self): return self.solver_model.intercept_ diff --git a/python/cuml/cuml/linear_model/ridge.pyx b/python/cuml/cuml/linear_model/ridge.pyx index bd039867f3..26701baad2 100644 --- a/python/cuml/cuml/linear_model/ridge.pyx +++ b/python/cuml/cuml/linear_model/ridge.pyx @@ -1,5 +1,5 @@ # -# Copyright (c) 2019-2024, NVIDIA CORPORATION. +# Copyright (c) 2019-2025, NVIDIA CORPORATION. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -24,7 +24,7 @@ import warnings from libc.stdint cimport uintptr_t -from cuml.common.array_descriptor import CumlArrayDescriptor +from cuml.common import CumlArrayDescriptor, set_output_type, convert_cuml_arrays from cuml.internals.base import UniversalBase from cuml.internals.mixins import RegressorMixin, FMajorInputTagMixin from cuml.internals.array import CumlArray @@ -257,6 +257,7 @@ class Ridge(UniversalBase, @generate_docstring() @enable_device_interop + @set_output_type("X") def fit(self, X, y, convert_dtype=True, sample_weight=None) -> "Ridge": """ Fit the model with X and y. diff --git a/python/cuml/cuml/solvers/qn.pyx b/python/cuml/cuml/solvers/qn.pyx index 72f51c25b2..a30fb7bc2e 100644 --- a/python/cuml/cuml/solvers/qn.pyx +++ b/python/cuml/cuml/solvers/qn.pyx @@ -1,4 +1,4 @@ -# Copyright (c) 2019-2024, NVIDIA CORPORATION. +# Copyright (c) 2019-2025, NVIDIA CORPORATION. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -25,13 +25,14 @@ from libc.stdint cimport uintptr_t import cuml.internals from cuml.internals.array import CumlArray from cuml.internals.base import Base -from cuml.common.array_descriptor import CumlArrayDescriptor +from cuml.common import CumlArrayDescriptor from cuml.internals.array_sparse import SparseCumlArray from cuml.internals.global_settings import GlobalSettings from cuml.common.doc_utils import generate_docstring from cuml.common import input_to_cuml_array from cuml.internals.mixins import FMajorInputTagMixin from cuml.common.sparse_utils import is_sparse +from cuml.common import convert_cuml_arrays, set_output_type IF GPUBUILD == 1: @@ -435,7 +436,7 @@ class QN(Base, self.loss = loss @property - @cuml.internals.api_base_return_array_skipall + @convert_cuml_arrays() def coef_(self): if self._coef_ is None: return None @@ -456,6 +457,7 @@ class QN(Base, self._coef_ = value @generate_docstring(X='dense_sparse') + @set_output_type("X") def fit(self, X, y, sample_weight=None, convert_dtype=True) -> "QN": """ Fit the model with X and y. @@ -642,7 +644,6 @@ class QN(Base, return self - @cuml.internals.api_base_return_array_skipall def _decision_function(self, X, convert_dtype=True) -> CumlArray: """ Gives confidence score for X @@ -786,7 +787,7 @@ class QN(Base, 'description': 'Predicted values', 'shape': '(n_samples, 1)' }) - @cuml.internals.api_base_return_array(get_output_dtype=True) + @convert_cuml_arrays() def predict(self, X, convert_dtype=True) -> CumlArray: """ Predicts the y for X. @@ -905,6 +906,7 @@ class QN(Base, return preds + @convert_cuml_arrays() def score(self, X, y): if GPUBUILD == 1: return accuracy_score(y, self.predict(X)) diff --git a/test_linear_model.py b/test_linear_model.py index 5fe9daad3b..9ad7453277 100644 --- a/test_linear_model.py +++ b/test_linear_model.py @@ -1,6 +1,8 @@ # Copyright (c) 2025, NVIDIA CORPORATION. -from cuml import LinearRegression +# from cuml import LinearRegression as Estimator +# from cuml import Ridge as Estimator +from cuml import LogisticRegression as Estimator from cuml import using_output_type import numpy as np @@ -11,15 +13,12 @@ X, y = make_regression(n_samples=20, n_features=2, noise=0.1, random_state=42) # Instantiate and train the estimator -model = LinearRegression() +model = Estimator() model.fit(X, y) # Make predictions predictions = model.predict(X) print("Predictions:", predictions, type(predictions)) -print("coef", model.coef_, type(model.coef_)) - with using_output_type("cupy"): assert isinstance(model.coef_, cp.ndarray) - assert isinstance(model.coef_, np.ndarray) From d6ba7cf0e1ce00f0076118e1edbd15f448726f62 Mon Sep 17 00:00:00 2001 From: Simon Adorf Date: Tue, 21 Jan 2025 22:43:20 +0000 Subject: [PATCH 21/22] fixup! Implement logistic regression. --- python/cuml/cuml/internals/type_reflection.py | 29 +++++++------------ .../cuml/linear_model/logistic_regression.pyx | 27 +++++++++-------- python/cuml/cuml/solvers/qn.pyx | 4 +-- 3 files changed, 26 insertions(+), 34 deletions(-) diff --git a/python/cuml/cuml/internals/type_reflection.py b/python/cuml/cuml/internals/type_reflection.py index 47ca60fc91..bf26b5e88d 100644 --- a/python/cuml/cuml/internals/type_reflection.py +++ b/python/cuml/cuml/internals/type_reflection.py @@ -3,6 +3,10 @@ from cuml.internals.global_settings import GlobalSettings from cuml.internals.array import CumlArray from cuml.internals.memory_utils import using_output_type +from cuml.internals.input_utils import ( + determine_array_type, + determine_array_dtype, +) import inspect from collections.abc import Sequence @@ -47,15 +51,16 @@ def as_cuml_array(X, dtype=None) -> CumlArray: class CumlArrayDescriptor: - def __init__(self, order="K"): + def __init__(self, order="K", dtype=None): self.order = order + self.dtype = dtype def __set_name__(self, owner, name): self.name = name def __set__(self, obj, value): # Save the provided value as CumlArray and initialize output cache. - dtype = _get_dtype(obj) + dtype = self.dtype or _get_dtype(obj) setattr( obj, f"_{self.name}_data", @@ -108,22 +113,6 @@ def __delete__(self, obj): # Type reflection -def determine_array_type(value) -> str: - """Utility function to identify the array type.""" - if isinstance(value, CumlArray): - return "cuml" - elif isinstance(value, np.ndarray): - return "numpy" - elif isinstance(value, cp.ndarray): - return "cupy" - else: - return ValueError(f"Unknown array type: {type(value)}") - - -def determine_array_dtype(value): - return value.dtype - - def _set_output_type(obj: Any, output_type: str): setattr(obj, "_output_type", output_type) @@ -168,6 +157,10 @@ def inner(obj, *args, **kwargs): if isinstance(self.to, TypeOfArgument): arg_value = bound_args.arguments.get(self.to.argument_name) arg_type = determine_array_type(arg_value) + if arg_type is None: + raise TypeError( + f"Argument for {self.to.argument_name} must be array-like." + ) dtype = self.dtype or determine_array_dtype(arg_value) _set_output_type(obj, arg_type) _set_dtype(obj, dtype) diff --git a/python/cuml/cuml/linear_model/logistic_regression.pyx b/python/cuml/cuml/linear_model/logistic_regression.pyx index 44f1f51825..51da5fe845 100644 --- a/python/cuml/cuml/linear_model/logistic_regression.pyx +++ b/python/cuml/cuml/linear_model/logistic_regression.pyx @@ -22,7 +22,6 @@ from cuml.internals.safe_imports import cpu_only_import from cuml.internals.safe_imports import gpu_only_import import pprint -import cuml.internals from cuml.solvers import QN from cuml.internals.base import UniversalBase from cuml.internals.mixins import ClassifierMixin, FMajorInputTagMixin, SparseInputTagMixin @@ -188,7 +187,7 @@ class LogisticRegression(UniversalBase, """ _cpu_estimator_import_path = 'sklearn.linear_model.LogisticRegression' - classes_ = CumlArrayDescriptor(order='F') + classes_ = CumlArrayDescriptor(order='F', dtype="= self._num_classes or self.classes_[i] != c: - msg = "Class label {} not present.".format(c) - raise ValueError(msg) + np_classes = self.classes_.to_host_array() + for c in self.expl_spec_weights_.to_host_array(): + i = np.searchsorted(np_classes, c) + if i >= self._num_classes or np_classes[i] != c: + msg = "Class label {} not present.".format(c) + raise ValueError(msg) if self.class_weight is not None: if self.class_weight == 'balanced': @@ -429,11 +428,11 @@ class LogisticRegression(UniversalBase, """ Predicts the class probabilities for each class in X """ - return self._predict_proba_impl( + return CumlArray(self._predict_proba_impl( X, convert_dtype=convert_dtype, log_proba=False - ) + )) @generate_docstring(X='dense_sparse', return_values={'name': 'preds', @@ -448,11 +447,11 @@ class LogisticRegression(UniversalBase, Predicts the log class probabilities for each class in X """ - return self._predict_proba_impl( + return CumlArray(self._predict_proba_impl( X, convert_dtype=convert_dtype, log_proba=True - ) + )) def _predict_proba_impl(self, X, diff --git a/python/cuml/cuml/solvers/qn.pyx b/python/cuml/cuml/solvers/qn.pyx index a30fb7bc2e..96ad86135b 100644 --- a/python/cuml/cuml/solvers/qn.pyx +++ b/python/cuml/cuml/solvers/qn.pyx @@ -948,9 +948,9 @@ class QN(Base, _num_classes = self.get_num_classes(_num_classes_dim) if _num_classes == 2: - self.intercept_ = CumlArray.zeros(shape=1) + self.intercept_ = CumlArray.zeros(shape=1, dtype=self.dtype) else: - self.intercept_ = CumlArray.zeros(shape=_num_classes) + self.intercept_ = CumlArray.zeros(shape=_num_classes, dtype=self.dtype) @classmethod def _get_param_names(cls): From 47eeff7a162798967c79365170fa0cbd8af3c4e9 Mon Sep 17 00:00:00 2001 From: Simon Adorf Date: Tue, 21 Jan 2025 22:48:01 +0000 Subject: [PATCH 22/22] fixup! Implement logistic regression. --- python/cuml/cuml/linear_model/logistic_regression.pyx | 1 - 1 file changed, 1 deletion(-) diff --git a/python/cuml/cuml/linear_model/logistic_regression.pyx b/python/cuml/cuml/linear_model/logistic_regression.pyx index 51da5fe845..6ef481f249 100644 --- a/python/cuml/cuml/linear_model/logistic_regression.pyx +++ b/python/cuml/cuml/linear_model/logistic_regression.pyx @@ -188,7 +188,6 @@ class LogisticRegression(UniversalBase, _cpu_estimator_import_path = 'sklearn.linear_model.LogisticRegression' classes_ = CumlArrayDescriptor(order='F', dtype="