Filtering using agg.stats without collecting to local value

Hi,
I’m trying to perform the very simple task of filtering on a column by excluding samples > 3 standard deviations from the mean. I’m using values from hl.agg.stats() but would like to do this without saving these values to a local variable with mt.aggregate_cols(), because I have numerous filtering steps preceding this one, and do not want to trigger their execution until the very end. I’ve generated two sets of code that in my mind should produce the same results, but do not.

1. Without storing stats to local variable

het_stats = hl.agg.stats(mt_res.c_het_val)
het_upper = het_stats.mean + (3 * het_stats.stdev)
het_lower = het_stats.mean - (3 * het_stats.stdev)

mt_test = mt_res.filter_cols(
(mt_res.c_het_val < het_upper) &
(mt_res.c_het_val > het_lower)
)
mt_test.count_cols()
684

2. Storing stats to local variable

het_stats = mt_res.aggregate_cols(hl.agg.stats(mt_res.c_het_val))
het_upper = het_stats.mean + (3 * het_stats.stdev)
het_lower = het_stats.mean - (3 * het_stats.stdev)

mt_test = mt_res.filter_cols(
(mt_res.c_het_val < het_upper) &
(mt_res.c_het_val > het_lower)
)
mt_test.count_cols()
1422

The correct output is produced by the code that stores the stats to a local variable. I suspect the issue has to do with what exactly hl.agg.stats() is using as input, but I can’t figure out how to make it act as expected in the first case. Any insight would be greatly appreciated.

You’re running into one of the subtlest parts of Hail’s aggregator interface, which is also the reason aggregators in Hail are so flexible: an aggregation derives its meaning from its context.

Suppose I have the following:

hl.agg.stats(mt.DP) # stats of an entry field

What is its value? The answer is that there’s not enough information here. We can use this in a variety of places:

# global stats
mt.aggregate_entries(hl.agg.stats(mt.DP))

# aggregate over entries per row
mt = mt.annotate_rows(dp_stats = hl.agg.stats(mt.DP))

# aggregate over entries per column
mt = mt.annotate_cols(dp_stats = hl.agg.stats(mt.DP))

# aggregate over entries per row per column group
mt2 = mt.group_by(mt.pop).aggregate(dp_stats = hl.agg.stats(mt.DP))

...etc...
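To make the context dependence concrete, here is a plain-Python analogy (not Hail code). The sample names, values, and the `bounds` helper are all made up for illustration. It applies the same mean ± 3 SD rule in two contexts: once across all columns, and once within each column across its rows, where a column-level value is simply repeated for every row.

```python
from statistics import mean, pstdev

# Hypothetical toy data: one per-sample (column) value, and a row count.
c_het_val = {"s1": 0.10, "s2": 0.12, "s3": 0.11, "s4": 0.50}
n_rows = 5

def bounds(values):
    """Return (mean - 3*SD, mean + 3*SD) for a list of numbers."""
    m, sd = mean(values), pstdev(values)
    return m - 3 * sd, m + 3 * sd

# Context A: aggregate across columns, one set of bounds for everyone.
lo, hi = bounds(list(c_het_val.values()))
kept_across_columns = [s for s, v in c_het_val.items() if lo < v < hi]

# Context B: aggregate within each column over its rows. The column value
# is the same on every row, so the SD is 0 and the bounds collapse onto
# the value itself; the strict inequalities then reject the column.
kept_within_column = []
for s, v in c_het_val.items():
    lo, hi = bounds([v] * n_rows)
    if lo < v < hi:
        kept_within_column.append(s)
```

Same expression, two different answers, purely because of what is being aggregated over. In a real pipeline the within-column SD may come out as tiny floating-point noise rather than exactly 0, so some columns could still pass by accident, but that is a guess and not something this toy reproduces.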

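So you can’t just use hl.agg.stats(mt_res.c_het_val) on its own; it needs to be inside something that defines its context, and aggregate_cols is the right context here. Inside filter_cols, the aggregation runs over the entries of each column (as in the annotate_cols example above) rather than across columns, which is why your first snippet gives a different count. We’ve wrestled with the right interface for making aggregate_cols lazy, but right now there’s a protected flag for it: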

het_stats = mt_res.aggregate_cols(hl.agg.stats(mt_res.c_het_val), _localize=False)

I see, thank you for the very detailed, clear and quick response! I figured context had something to do with it. I thought that using the aggregator inside filter_cols() would provide the right context, but I guess not. So if I understand correctly, I can add the “_localize=False” flag to aggregate_cols() and it will not trigger collection? Out of curiosity, does this also work for aggregate_rows()? I will likely run into similar issues there. Thanks!

Yes, _localize=False works on aggregate_cols, aggregate_rows, aggregate_entries, Table.aggregate, and Table.collect, I think.

Great, this is working as expected. Thanks so much for your help! Perhaps it would be helpful to add this flag to the documentation, as I imagine there are others who would find it useful as well.

Agreed. Will add a note to do that. We haven’t thought of a better name for the param in the ~18 months we’ve had this functionality!
