Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

time windows in statistics #2948

Draft
wants to merge 20 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from 2 commits
Commits
Show all changes
20 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 21 additions & 8 deletions c/tests/test_stats.c
Original file line number Diff line number Diff line change
Expand Up @@ -869,17 +869,20 @@ verify_one_way_stat_func_errors(tsk_treeseq_t *ts, one_way_sample_stat_method *m
typedef int one_way_sample_stat_method_tw(const tsk_treeseq_t *self,
tsk_size_t num_sample_sets, const tsk_size_t *sample_set_sizes,
const tsk_id_t *sample_sets, tsk_size_t num_windows, const double *windows,
tsk_size_t num_time_windows, const double *time_windows,
tsk_flags_t options, double *result);
tsk_size_t num_time_windows, const double *time_windows, tsk_flags_t options,
double *result);

// Temporary duplicate for time-windows-having methods
static void
verify_one_way_stat_func_errors_tw(tsk_treeseq_t *ts, one_way_sample_stat_method_tw *method)
verify_one_way_stat_func_errors_tw(
tsk_treeseq_t *ts, one_way_sample_stat_method_tw *method)
{
int ret;
tsk_id_t num_nodes = (tsk_id_t) tsk_treeseq_get_num_nodes(ts);
tsk_id_t samples[] = { 0, 1, 2, 3 };
tsk_size_t sample_set_sizes = 4;
double windows[] = { 0, 0, 0 };
double time_windows[] = { 0, 0, 0 };
double result;

ret = method(ts, 0, &sample_set_sizes, samples, 0, NULL, 0, NULL, 0, &result);
Expand Down Expand Up @@ -918,6 +921,15 @@ verify_one_way_stat_func_errors_tw(tsk_treeseq_t *ts, one_way_sample_stat_method

ret = method(ts, 1, &sample_set_sizes, samples, 2, windows, 0, NULL, 0, &result);
CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_BAD_WINDOWS);

/* Time window errors */
ret = method(
ts, 1, &sample_set_sizes, samples, 0, NULL, 0, time_windows, 0, &result);
CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_BAD_NUM_WINDOWS);

ret = method(
ts, 1, &sample_set_sizes, samples, 0, NULL, 2, time_windows, 0, &result);
CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_BAD_WINDOWS);
}

static void
Expand Down Expand Up @@ -1258,23 +1270,24 @@ verify_afs(tsk_treeseq_t *ts)
sample_set_sizes[0] = n - 2;
sample_set_sizes[1] = 2;
ret = tsk_treeseq_allele_frequency_spectrum(
ts, 2, sample_set_sizes, samples, 0, NULL, 0, NULL, 0, result);
ts, 2, sample_set_sizes, samples, 0, NULL, 0, NULL, 0, result);
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

uh-oh, are these tabs?

CU_ASSERT_EQUAL_FATAL(ret, 0);

ret = tsk_treeseq_allele_frequency_spectrum(
ts, 2, sample_set_sizes, samples, 0, NULL, 0, NULL, TSK_STAT_POLARISED, result);
ts, 2, sample_set_sizes, samples, 0, NULL, 0, NULL, TSK_STAT_POLARISED, result);
CU_ASSERT_EQUAL_FATAL(ret, 0);

ret = tsk_treeseq_allele_frequency_spectrum(ts, 2, sample_set_sizes, samples, 0,
NULL, 0, NULL, TSK_STAT_POLARISED | TSK_STAT_SPAN_NORMALISE, result);
NULL, 0, NULL, TSK_STAT_POLARISED | TSK_STAT_SPAN_NORMALISE, result);
CU_ASSERT_EQUAL_FATAL(ret, 0);

ret = tsk_treeseq_allele_frequency_spectrum(ts, 2, sample_set_sizes, samples, 0,
NULL, 0, NULL, TSK_STAT_BRANCH | TSK_STAT_POLARISED | TSK_STAT_SPAN_NORMALISE, result);
NULL, 0, NULL, TSK_STAT_BRANCH | TSK_STAT_POLARISED | TSK_STAT_SPAN_NORMALISE,
result);
CU_ASSERT_EQUAL_FATAL(ret, 0);

ret = tsk_treeseq_allele_frequency_spectrum(ts, 2, sample_set_sizes, samples, 0,
NULL, 0, NULL, TSK_STAT_BRANCH | TSK_STAT_SPAN_NORMALISE, result);
NULL, 0, NULL, TSK_STAT_BRANCH | TSK_STAT_SPAN_NORMALISE, result);
CU_ASSERT_EQUAL_FATAL(ret, 0);

free(result);
Expand Down
116 changes: 54 additions & 62 deletions c/tskit/trees.c
Original file line number Diff line number Diff line change
Expand Up @@ -1233,8 +1233,7 @@
}

static int
tsk_treeseq_check_time_windows(tsk_size_t num_windows,
const double *windows)
tsk_treeseq_check_time_windows(tsk_size_t num_windows, const double *windows)
{
int ret = TSK_ERR_BAD_WINDOWS;
tsk_size_t j;
Expand All @@ -1245,10 +1244,11 @@
}

if (windows[0] < 0) {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If the code currently assumes that the first time window starts at 0, we should check for equality with 0 here (rather than only rejecting negative values).

goto out;
goto out;

Check warning on line 1247 in c/tskit/trees.c

View check run for this annotation

Codecov / codecov/patch

c/tskit/trees.c#L1247

Added line #L1247 was not covered by tests
}
if (windows[num_windows] > INFINITY) {
goto out;

if (windows[0] != 0) {
goto out;

Check warning on line 1251 in c/tskit/trees.c

View check run for this annotation

Codecov / codecov/patch

c/tskit/trees.c#L1251

Added line #L1251 was not covered by tests
}

for (j = 0; j < num_windows; j++) {
Expand All @@ -1256,7 +1256,7 @@
goto out;
}
}
ret = 0;

Check warning on line 1259 in c/tskit/trees.c

View check run for this annotation

Codecov / codecov/patch

c/tskit/trees.c#L1259

Added line #L1259 was not covered by tests
out:
return ret;
}
Expand Down Expand Up @@ -3513,35 +3513,21 @@
return ret;
}

#define MAX(a,b) ((a) > (b) ? (a) : (b))
#define MIN(a,b) ((a) < (b) ? (a) : (b))

/* int getValue_nDimensions( int * baseAddress, int * indexes, int nDimensions ) { */
/* int i; */
/* int offset = 0; */
/* for( i = 0; i < nDimensions; i++ ) { */
/* offset += pow(LEN,i) * indexes[nDimensions - (i + 1)]; */
/* } */

/* return *(baseAddress + offset); */
/* } */

static int TSK_WARN_UNUSED
tsk_treeseq_update_branch_afs(const tsk_treeseq_t *self, tsk_id_t u, double right,
double *restrict last_update,
const double *restrict time, tsk_id_t *restrict parent, const double *time_windows,
const double *counts, tsk_size_t num_sample_sets,
tsk_size_t num_time_windows, tsk_size_t window_index, tsk_size_t time_window_index,
const tsk_size_t *result_dims, tsk_flags_t options, double *result)
double *restrict last_update, const double *restrict time, tsk_id_t *restrict parent,
const double *time_windows, const double *counts, tsk_size_t num_sample_sets,
tsk_size_t num_time_windows, tsk_size_t window_index, const tsk_size_t *result_dims,
tsk_flags_t options, double *result)
{
int ret = 0;
tsk_size_t afs_size;
tsk_size_t k;
tsk_size_t time_window_index;
double *afs;
tsk_size_t *coordinate = tsk_malloc(num_sample_sets * sizeof(*coordinate));
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Gee, wouldn't it be better to malloc this outside this function, and pass it in? (I honestly don't know...)

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

HA! IDK, I'll leave it there for now, but yeah, maybe!

bool polarised = !!(options & TSK_STAT_POLARISED);
const double *count_row = GET_2D_ROW(counts, num_sample_sets + 1, u);
/* double x = (right - last_update[u]) * branch_length[u]; */
double x = 0;
double t_v = 0;
double tw_branch_length = 0;
Expand All @@ -3550,23 +3536,31 @@
ret = TSK_ERR_NO_MEMORY;
goto out;
}
if (parent[u] != -1){
t_v = time[parent[u]];
if (0 < all_samples && all_samples < self->num_samples) {
for (time_window_index = 0; time_window_index < num_time_windows; time_window_index++){
afs_size = result_dims[num_sample_sets];
afs = result + afs_size * (window_index * num_time_windows + time_window_index);
for (k = 0; k < num_sample_sets; k++) {
coordinate[k] = (tsk_size_t) count_row[k];
}
if (!polarised){
fold(coordinate, result_dims, num_sample_sets);
}
tw_branch_length = MIN(time_windows[time_window_index + 1], t_v) - MAX(time_windows[0], time[u]);
x = (right - last_update[u]) * tw_branch_length;
increment_nd_array_value(afs, num_sample_sets, result_dims, coordinate, x);
}
}
if (parent[u] != TSK_NULL) {
t_v = time[parent[u]];
if (0 < all_samples && all_samples < self->num_samples) {
time_window_index = 0;
while (time_window_index < num_time_windows
&& time_windows[time_window_index] < t_v) {
/* for (time_window_index = 0; time_window_index < num_time_windows;
* time_window_index++){ */
afs_size = result_dims[num_sample_sets];
afs = result
+ afs_size * (window_index * num_time_windows + time_window_index);
for (k = 0; k < num_sample_sets; k++) {
coordinate[k] = (tsk_size_t) count_row[k];
}
if (!polarised) {
fold(coordinate, result_dims, num_sample_sets);
}
tw_branch_length = TSK_MIN(time_windows[time_window_index + 1], t_v)
- TSK_MAX(time_windows[0], time[u]);
x = (right - last_update[u]) * tw_branch_length;
increment_nd_array_value(
afs, num_sample_sets, result_dims, coordinate, x);
time_window_index++;
}
}
}
last_update[u] = right;
out:
Expand All @@ -3582,7 +3576,7 @@
{
int ret = 0;
tsk_id_t u, v;
tsk_size_t window_index, time_window_index;
tsk_size_t window_index;
tsk_size_t num_nodes = self->tables->nodes.num_rows;
const tsk_id_t num_edges = (tsk_id_t) self->tables->edges.num_rows;
const tsk_id_t *restrict I = self->tables->indexes.edge_insertion_order;
Expand Down Expand Up @@ -3616,26 +3610,23 @@
tk = 0;
t_left = 0;
window_index = 0;
time_window_index = 0;
while (tj < num_edges || t_left < sequence_length) {
tsk_bug_assert(window_index < num_windows);
while (tk < num_edges && edge_right[O[tk]] == t_left) {
h = O[tk];
tk++;
u = edge_child[h];
v = edge_parent[h];
ret = tsk_treeseq_update_branch_afs(self, u, t_left,
last_update, node_time, parent, time_windows, counts, num_sample_sets,
num_time_windows, window_index, time_window_index,
result_dims, options, result);
ret = tsk_treeseq_update_branch_afs(self, u, t_left, last_update, node_time,
parent, time_windows, counts, num_sample_sets, num_time_windows,
window_index, result_dims, options, result);
if (ret != 0) {
goto out;
}
while (v != TSK_NULL) {
ret = tsk_treeseq_update_branch_afs(self, v, t_left,
last_update, node_time, parent, time_windows, counts,
num_sample_sets, num_time_windows, window_index,
time_window_index, result_dims, options, result);
ret = tsk_treeseq_update_branch_afs(self, v, t_left, last_update,
node_time, parent, time_windows, counts, num_sample_sets,
num_time_windows, window_index, result_dims, options, result);
if (ret != 0) {
goto out;
}
Expand All @@ -3654,10 +3645,9 @@
parent[u] = v;
branch_length[u] = node_time[v] - node_time[u];
while (v != TSK_NULL) {
ret = tsk_treeseq_update_branch_afs(self, v, t_left,
last_update, node_time, parent, time_windows, counts,
num_sample_sets, num_time_windows, window_index,
time_window_index, result_dims, options, result);
ret = tsk_treeseq_update_branch_afs(self, v, t_left, last_update,
node_time, parent, time_windows, counts, num_sample_sets,
num_time_windows, window_index, result_dims, options, result);
if (ret != 0) {
goto out;
}
Expand All @@ -3679,10 +3669,9 @@
/* Flush the contributions of all nodes to the current window */
for (u = 0; u < (tsk_id_t) num_nodes; u++) {
tsk_bug_assert(last_update[u] < w_right);
ret = tsk_treeseq_update_branch_afs(self, u, w_right,
last_update, node_time, parent, time_windows, counts,
num_sample_sets, num_time_windows, window_index,
time_window_index, result_dims, options, result);
ret = tsk_treeseq_update_branch_afs(self, u, w_right, last_update,
node_time, parent, time_windows, counts, num_sample_sets,
num_time_windows, window_index, result_dims, options, result);
if (ret != 0) {
goto out;
}
Expand Down Expand Up @@ -3755,8 +3744,12 @@
num_time_windows = 1;
time_windows = default_time_windows;
} else {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

After this line is probably the right place to check if it's mode="site" and throw an error?

ret = tsk_treeseq_check_time_windows(
num_time_windows, time_windows);
if (stat_site
&& tsk_memcmp(time_windows, default_time_windows, sizeof(double)) != 0) {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hm, this is a bit awkward - what if instead we used num_time_windows=0 to mean "default/no time windows"?

Copy link
Collaborator

@tforest tforest Nov 21, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

But time_windows is always initialized by default as [0, inf], so num_time_windows=2; comparing to the default was the clearest approach I found for now. But maybe the problem lies in the initialization caused by the parsing of the windows in the first place.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh, wait - we're already in the else clause where we know that time_windows != NULL. So I don't think we need to check this at all - just throw the error?

Suggested change
&& tsk_memcmp(time_windows, default_time_windows, sizeof(double)) != 0) {
) {

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So, if someone explicitly specifies time_windows = [0, np.inf], mode="site" then they'll get the error, but that's okay - the error message says "you can't specify time windows", not "time windows must be [0, Inf)" (I think that's what it says anyhow).

ret = TSK_ERR_UNSUPPORTED_STAT_MODE;
goto out;

Check warning on line 3750 in c/tskit/trees.c

View check run for this annotation

Codecov / codecov/patch

c/tskit/trees.c#L3749-L3750

Added lines #L3749 - L3750 were not covered by tests
}
ret = tsk_treeseq_check_time_windows(num_time_windows, time_windows);
if (ret != 0) {
goto out;
}
Expand Down Expand Up @@ -3796,7 +3789,6 @@
count_row[num_sample_sets] = 1;
}
result_dims[num_sample_sets] = (tsk_size_t) afs_size;
// Initiate memory for result array
tsk_memset(result, 0, num_windows * num_time_windows * afs_size * sizeof(*result));

if (stat_site) {
Expand Down
15 changes: 7 additions & 8 deletions python/_tskitmodule.c
Original file line number Diff line number Diff line change
Expand Up @@ -9077,7 +9077,7 @@
npy_intp *shape;

windows_array = (PyArrayObject *) PyArray_FROMANY(
windows, NPY_FLOAT64, 1, 1, NPY_ARRAY_IN_ARRAY);
windows, NPY_FLOAT64, 1, 1, NPY_ARRAY_IN_ARRAY);
if (windows_array == NULL) {
goto out;
}
Expand All @@ -9095,7 +9095,6 @@
return ret;
}


static PyArrayObject *
TreeSequence_allocate_results_array(
TreeSequence *self, tsk_flags_t mode, tsk_size_t num_windows, tsk_size_t output_dim)
Expand Down Expand Up @@ -9440,13 +9439,13 @@
TreeSequence *self, PyObject *args, PyObject *kwds)
{
PyObject *ret = NULL;
static char *kwlist[] = { "sample_set_sizes", "sample_sets", "windows", "time_windows", "mode",
"span_normalise", "polarised", NULL };
static char *kwlist[] = { "sample_set_sizes", "sample_sets", "windows",
"time_windows", "mode", "span_normalise", "polarised", NULL };
PyObject *sample_set_sizes = NULL;
PyObject *sample_sets = NULL;
PyObject *windows = NULL;
PyObject *time_windows = NULL;
char *mode = "NULL";
char *mode = NULL;
PyArrayObject *sample_set_sizes_array = NULL;
PyArrayObject *sample_sets_array = NULL;
PyArrayObject *windows_array = NULL;
Expand All @@ -9463,7 +9462,7 @@
goto out;
}
if (!PyArg_ParseTupleAndKeywords(args, kwds, "OOOO|sii", kwlist, &sample_set_sizes,
&sample_sets, &windows, &time_windows, &mode, &span_normalise, &polarised)) {
&sample_sets, &windows, &time_windows, &mode, &span_normalise, &polarised)) {
goto out;
}
if (parse_stats_mode(mode, &options) != 0) {
Expand All @@ -9484,7 +9483,7 @@
goto out;
}
if (parse_windows(time_windows, &time_windows_array, &num_time_windows) != 0) {
goto out;

Check warning on line 9486 in python/_tskitmodule.c

View check run for this annotation

Codecov / codecov/patch

python/_tskitmodule.c#L9486

Added line #L9486 was not covered by tests
}
shape = PyMem_Malloc((num_sample_sets + 1 + 1) * sizeof(*shape));
if (shape == NULL) {
Expand All @@ -9496,8 +9495,8 @@
for (k = 0; k < num_sample_sets; k++) {
shape[k + 1 + 1] = 1 + sizes[k];
}
result_array
= (PyArrayObject *) PyArray_SimpleNew(1 + 1 + num_sample_sets, shape, NPY_FLOAT64);
result_array = (PyArrayObject *) PyArray_SimpleNew(
1 + 1 + num_sample_sets, shape, NPY_FLOAT64);
if (result_array == NULL) {
goto out;
}
Expand Down
2 changes: 1 addition & 1 deletion python/tests/test_lowlevel.py
Original file line number Diff line number Diff line change
Expand Up @@ -2263,7 +2263,7 @@ def test_output_dims(self):
n = len(samples)
time_windows = [0, np.inf]
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should probably be testing with more than one time window here as well - adding another `for` loop to the loops below will make a lot of combinations, but only a factor of 2 more? (with 1 time window and then with 2 time windows, maybe)


for mode in ["site", "branch"]:
for mode in ["branch"]:
for s in [[n], [n - 2, 2], [n - 4, 2, 2], [1] * n]:
s = np.array(s, dtype=np.uint32)
windows = [0, L]
Expand Down
7 changes: 3 additions & 4 deletions python/tskit/trees.py
Original file line number Diff line number Diff line change
Expand Up @@ -7637,7 +7637,6 @@
# Note: need to make sure windows is a string or we try to compare the
# target with a numpy array elementwise.
if windows is None:
# initiate default spanning windows
windows = [0.0, self.sequence_length]
elif isinstance(windows, str):
if windows == "trees":
Expand Down Expand Up @@ -7684,6 +7683,8 @@
stat = stat[0, :, :]
elif strip_timewin:
stat = stat[:, 0, :]
elif strip_win and strip_timewin:
stat = stat[0, 0, :]

Check warning on line 7687 in python/tskit/trees.py

View check run for this annotation

Codecov / codecov/patch

python/tskit/trees.py#L7687

Added line #L7687 was not covered by tests
return stat

def __one_way_sample_set_stat(
Expand Down Expand Up @@ -7747,7 +7748,7 @@
polarised=False,
):
if sample_sets is None:
sample_sets = self.samples()

Check warning on line 7751 in python/tskit/trees.py

View check run for this annotation

Codecov / codecov/patch

python/tskit/trees.py#L7751

Added line #L7751 was not covered by tests

# First try to convert to a 1D numpy array. If it is, then we strip off
# the corresponding dimension from the output.
Expand All @@ -7760,14 +7761,14 @@
# If we've successfully converted sample_sets to a 1D numpy array
# of integers then drop the dimension
if len(sample_sets.shape) == 1:
sample_sets = [sample_sets]
drop_dimension = True

Check warning on line 7765 in python/tskit/trees.py

View check run for this annotation

Codecov / codecov/patch

python/tskit/trees.py#L7764-L7765

Added lines #L7764 - L7765 were not covered by tests

sample_set_sizes = np.array(
[len(sample_set) for sample_set in sample_sets], dtype=np.uint32
)
if np.any(sample_set_sizes == 0):
raise ValueError("Sample sets must contain at least one element")

Check warning on line 7771 in python/tskit/trees.py

View check run for this annotation

Codecov / codecov/patch

python/tskit/trees.py#L7771

Added line #L7771 was not covered by tests

flattened = util.safe_np_int_cast(np.hstack(sample_sets), np.int32)
stat = self.__run_windowed_stat_tw(
Expand All @@ -7781,12 +7782,10 @@
polarised=polarised,
)
if drop_dimension:
stat = stat.reshape(stat.shape[:-1])

Check warning on line 7785 in python/tskit/trees.py

View check run for this annotation

Codecov / codecov/patch

python/tskit/trees.py#L7785

Added line #L7785 was not covered by tests
# TODO: Write test for this
if (stat.shape == () and windows is None) or (
stat.shape == () and time_windows is None
):
if stat.shape == () and windows is None and time_windows is None:
stat = stat[()]

Check warning on line 7788 in python/tskit/trees.py

View check run for this annotation

Codecov / codecov/patch

python/tskit/trees.py#L7788

Added line #L7788 was not covered by tests
return stat

def parse_sites(self, sites):
Expand Down
Loading