diff --git a/quantile_forest/_quantile_forest.py b/quantile_forest/_quantile_forest.py index b8dbe0c..2ab29f7 100755 --- a/quantile_forest/_quantile_forest.py +++ b/quantile_forest/_quantile_forest.py @@ -127,9 +127,7 @@ def fit(self, X, y, sample_weight=None, sparse_pickle=False): sample_weight : array-like of shape (n_samples,), default=None Sample weights. If None, then samples are equally weighted. Splits that would create child nodes with net zero or negative weight are - ignored while searching for a split in each node. In the case of - classification, splits are also ignored if they would result in any - single class carrying a negative weight in either child node. + ignored while searching for a split in each node. sparse_pickle : bool, default=False Pickle the underlying data structure using a SciPy sparse matrix. @@ -230,12 +228,12 @@ def _get_y_train_leaves_slice( X_leaves_bootstrap : array-like of shape (n_samples,) Leaf node indices of the bootstrap training samples. - sample_weight : array-like of shape (n_samples,), default=None + sample_weight : array-like of shape (n_samples, n_outputs), \ + default=None Sample weights. If None, then samples are equally weighted. Splits that would create child nodes with net zero or negative weight are - ignored while searching for a split in each node. In the case of - classification, splits are also ignored if they would result in any - single class carrying a negative weight in either child node. + ignored while searching for a split in each node. For each output, + the ordering of the weights correspond to the sorted samples. leaf_subsample : bool Subsample leaf nodes. If True, leaves are randomly sampled to size @@ -261,6 +259,9 @@ def _get_y_train_leaves_slice( """ n_outputs = bootstrap_indices.shape[1] + if sample_weight is not None: + sample_weight = np.squeeze(sample_weight) + shape = (max_node_count, n_outputs, max_samples_leaf) y_train_leaves_slice = np.zeros(shape, dtype=np.int64) @@ -319,10 +320,12 @@ def _get_y_train_leaves(self, X, y, sorter=None, sample_weight=None): The indices that would sort the target values in ascending order. Used to associate ``est.apply`` outputs with sorted target values. - sample_weight : array-like of shape (n_samples,), default=None + sample_weight : array-like of shape (n_samples, n_outputs), \ + default=None Sample weights. If None, then samples are equally weighted. Splits that would create child nodes with net zero or negative weight are - ignored while searching for a split in each node. + ignored while searching for a split in each node. For each output, + the ordering of the weights correspond to the sorted samples. Returns ------- @@ -394,9 +397,6 @@ def _get_y_train_leaves(self, X, y, sorter=None, sample_weight=None): if sample_count > max_samples_leaf: max_samples_leaf = sample_count - if sample_weight is not None: - sample_weight = np.squeeze(sample_weight) - y_train_leaves = [ self._get_y_train_leaves_slice( bootstrap_indices[:, i],