-
Notifications
You must be signed in to change notification settings - Fork 548
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Compute noise_variance_
in PCA implementation
#6234
base: branch-25.02
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,5 +1,5 @@ | ||
/* | ||
* Copyright (c) 2018-2024, NVIDIA CORPORATION. | ||
* Copyright (c) 2018-2025, NVIDIA CORPORATION. | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||
* you may not use this file except in compliance with the License. | ||
|
@@ -41,6 +41,7 @@ void truncCompExpVars(const raft::handle_t& handle, | |
math_t* components, | ||
math_t* explained_var, | ||
math_t* explained_var_ratio, | ||
math_t* noise_vars, | ||
const paramsTSVDTemplate<enum_solver>& prms, | ||
cudaStream_t stream) | ||
{ | ||
|
@@ -67,6 +68,20 @@ void truncCompExpVars(const raft::handle_t& handle, | |
prms.n_components, | ||
std::size_t(1), | ||
stream); | ||
|
||
// Compute the scalar noise_vars defined as (pseudocode) | ||
// (n_components < min(n_cols, n_rows)) ? explained_var_all[n_components:].mean() : 0 | ||
if (prms.n_components < prms.n_cols && prms.n_components < prms.n_rows) { | ||
raft::stats::mean(noise_vars, | ||
explained_var_all.data() + prms.n_components, | ||
std::size_t{1}, | ||
prms.n_cols - prms.n_components, | ||
false, | ||
true, | ||
stream); | ||
} else { | ||
raft::matrix::setValue(noise_vars, noise_vars, math_t{0}, 1, stream); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It would be better to use
|
||
} | ||
} | ||
|
||
/** | ||
|
@@ -116,7 +131,7 @@ void pcaFit(const raft::handle_t& handle, | |
raft::stats::cov( | ||
handle, cov.data(), input, mu, prms.n_cols, prms.n_rows, true, false, true, stream); | ||
truncCompExpVars( | ||
handle, cov.data(), components, explained_var, explained_var_ratio, prms, stream); | ||
handle, cov.data(), components, explained_var, explained_var_ratio, noise_vars, prms, stream); | ||
|
||
math_t scalar = (prms.n_rows - 1); | ||
raft::matrix::seqRoot(explained_var, singular_vals, scalar, n_components, stream, true); | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,5 +1,5 @@ | ||
/* | ||
* Copyright (c) 2019-2024, NVIDIA CORPORATION. | ||
* Copyright (c) 2019-2025, NVIDIA CORPORATION. | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||
* you may not use this file except in compliance with the License. | ||
|
@@ -69,7 +69,7 @@ void fit_impl(raft::handle_t& handle, | |
Stats::opg::cov(handle, cov, input_data, input_desc, mu_data, true, streams, n_streams); | ||
|
||
ML::truncCompExpVars<T, mg_solver>( | ||
handle, cov.ptr, components, explained_var, explained_var_ratio, prms, streams[0]); | ||
handle, cov.ptr, components, explained_var, explained_var_ratio, noise_vars, prms, streams[0]); | ||
|
||
T scalar = (prms.n_rows - 1); | ||
raft::matrix::seqRoot(explained_var, singular_vals, scalar, prms.n_components, streams[0], true); | ||
|
@@ -128,9 +128,6 @@ void fit_impl(raft::handle_t& handle, | |
streams, | ||
n_streams, | ||
verbose); | ||
for (std::uint32_t i = 0; i < n_streams; i++) { | ||
handle.sync_stream(streams[i]); | ||
} | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This sync point was unneeded, it was already handled after exiting this branch below. |
||
} else if (prms.algorithm == mg_solver::QR) { | ||
const raft::handle_t& h = handle; | ||
cudaStream_t stream = h.get_stream(); | ||
|
@@ -194,6 +191,20 @@ void fit_impl(raft::handle_t& handle, | |
std::size_t(1), | ||
stream); | ||
|
||
// Compute the scalar noise_vars defined as (pseudocode) | ||
// (n_components < min(n_cols, n_rows)) ? explained_var_all[n_components:].mean() : 0 | ||
if (prms.n_components < prms.n_cols && prms.n_components < prms.n_rows) { | ||
raft::stats::mean(noise_vars, | ||
explained_var_all.data() + prms.n_components, | ||
std::size_t{1}, | ||
prms.n_cols - prms.n_components, | ||
false, | ||
true, | ||
stream); | ||
} else { | ||
raft::matrix::setValue(noise_vars, noise_vars, T{0}, 1, stream); | ||
} | ||
|
||
raft::linalg::transpose(vMatrix.data(), prms.n_cols, stream); | ||
raft::matrix::truncZeroOrigin( | ||
vMatrix.data(), prms.n_cols, components, prms.n_components, prms.n_cols, stream); | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1,4 @@ | ||
# Copyright (c) 2019-2023, NVIDIA CORPORATION. | ||
# Copyright (c) 2019-2025, NVIDIA CORPORATION. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
|
@@ -72,13 +72,13 @@ def test_pca_fit(datatype, input_type, name, use_handle): | |
"components_", | ||
"explained_variance_", | ||
"explained_variance_ratio_", | ||
"noise_variance_", | ||
]: | ||
with_sign = False if attr in ["components_"] else True | ||
print(attr) | ||
print(getattr(cupca, attr)) | ||
print(getattr(skpca, attr)) | ||
cuml_res = getattr(cupca, attr) | ||
|
||
skl_res = getattr(skpca, attr) | ||
assert array_equal(cuml_res, skl_res, 1e-3, with_sign=with_sign) | ||
|
||
|
@@ -304,6 +304,22 @@ def test_sparse_pca_inputs(nrows, ncols, whiten, return_sparse, cupy_input): | |
assert array_equal(i_sparse, X.todense(), 1e-1, with_sign=True) | ||
|
||
|
||
@pytest.mark.parametrize( | ||
"n_samples, n_features", | ||
[ | ||
pytest.param(9, 20, id="n_samples <= n_components"), | ||
pytest.param(20, 10, id="n_features <= n_components"), | ||
], | ||
) | ||
def test_noise_variance_zero(n_samples, n_features): | ||
X, _ = make_blobs( | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can a test other than zero be added? To check the correctness of the computation There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This test is checking the case where the noise variance is defined as zero since there weren't enough samples or features. Other cases where it's non zero are already tested thoroughly above on line 75. |
||
n_samples=n_samples, n_features=n_features, random_state=0 | ||
) | ||
cupca = cuPCA(n_components=10) | ||
cupca.fit(X) | ||
assert cupca.noise_variance_.item() == 0 | ||
|
||
|
||
def test_exceptions(): | ||
with pytest.raises(NotFittedError): | ||
X = cp.random.random((10, 10)) | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can you use the newer mdspan-API of this
mean
function as well?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It's my understanding that since the PCA code here doesn't create streams from the handle (#2470), and is mostly using the legacy APIs here that take streams instead of handles, that switching to the mdspan APIs instead would require a substantial refactor. I'm happy to take that on in a follow-up, but would prefer to port all the code to the newer APIs at the same time, rather than try to force it in as part of this PR.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes that's a valid point, let's keep that port for an other dedicated PR.