Group by columns and aggregate entries over all entries in the group

Here’s another issue I have been struggling with for a bit. I am sure there is some obvious way to do it, but I can’t figure it out…

I have a MatrixTable with variants as rows and samples as columns. I want to group the samples by one of their column fields (obesity) and run a counter over all entries in each group.

This is what I have so far:

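import hail as hl

# 10 rows (variants) x 6 columns (samples), with a random boolean entry field n
t = hl.utils.range_matrix_table(10, 6)
t = t.annotate_entries(n=hl.rand_bool(0.66))
# the first three samples are 'lean', the last three are 'obese'
t = t.annotate_cols(obesity=hl.array(['lean', 'obese'])[t.col_idx // 3])
# the first five rows belong to BRCA1, the last five to KRAS
t = t.annotate_rows(gene=hl.array(['BRCA1', 'KRAS'])[t.row_idx // 5])
t.describe()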

----------------------------------------
Global fields:
    None
----------------------------------------
Column fields:
    'col_idx': int32
    'obesity': str
----------------------------------------
Row fields:
    'row_idx': int32
    'gene': str
----------------------------------------
Entry fields:
    'n': bool
----------------------------------------
Column key: ['col_idx']
Row key: ['row_idx']
----------------------------------------

t.show(n_cols=6)

         0      1      2      3      4      5
row_idx  n      n      n      n      n      n
int32    bool   bool   bool   bool   bool   bool
0        true   true   false  true   false  true
1        true   false  false  false  true   true
2        false  true   true   true   true   false
3        true   true   false  true   true   true
4        true   false  true   true   true   false
5        false  true   true   false  true   true
6        true   false  true   true   true   true
7        true   false  true   true   false  false
8        true   true   false  true   true   false
9        true   false  true   true   true   true

t.entries().show()

row_idx  gene     col_idx  obesity  n
int32    str      int32    str      bool
0        "BRCA1"  0        "lean"   true
0        "BRCA1"  1        "lean"   true
0        "BRCA1"  2        "lean"   true
0        "BRCA1"  3        "obese"  true
0        "BRCA1"  4        "obese"  false
0        "BRCA1"  5        "obese"  true
1        "BRCA1"  0        "lean"   false
1        "BRCA1"  1        "lean"   true
1        "BRCA1"  2        "lean"   true
1        "BRCA1"  3        "obese"  false

ta = t.group_cols_by(t.obesity)
ta = ta.aggregate_entries(c = hl.agg.counter(t.n)).result()

ta.show()

         'lean'             'obese'
row_idx  c                  c
int32    dict<bool, int64>  dict<bool, int64>
0        {false:1,true:2}   {false:1,true:2}
1        {false:2,true:1}   {false:1,true:2}
2        {false:1,true:2}   {false:1,true:2}
3        {false:1,true:2}   {true:3}
4        {false:1,true:2}   {false:1,true:2}
5        {false:1,true:2}   {false:1,true:2}
6        {false:1,true:2}   {true:3}
7        {false:1,true:2}   {false:2,true:1}
8        {false:1,true:2}   {false:1,true:2}
9        {false:1,true:2}   {true:3}

Surely I must be able to aggregate over all rows, rather than aggregate only PER row?

I think I understand that I need some sort of column annotation for the “lean” and “obese” columns. Since I want to collapse all the rows, I would not have any rows “left over”, so somehow I need to lose the rows, but that probably is not the way to think about it…

I could collapse the rows based on some row field like gene, but it does not seem possible to group by BOTH columns AND rows at once, so…

Any ideas?

I think I found the way to do this and it again involves explode!

ta = t.group_cols_by(t.obesity).aggregate(c = hl.agg.collect(t.n))
ta = ta.group_rows_by(ta.gene).aggregate(bla = hl.agg.explode(lambda element: hl.agg.counter(element), ta.c))

         'lean'              'obese'
gene     bla                 bla
str      dict<bool, int64>   dict<bool, int64>
"BRCA1"  {false:4,true:11}   {false:6,true:9}
"KRAS"   {false:5,true:10}   {false:7,true:8}

I DO have the feeling that COLLECT and EXPLODE are expensive actions, so if that is the case and if there are other options, let me know…
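
For anyone following along, here is a plain-Python model (not Hail code) of what the collect-then-explode pipeline above computes. The toy `genes`, `obesity`, and `entries` values are made up for illustration and are much smaller than the example table. It is only a sketch of the semantics, not of how Hail executes it.

```python
from collections import Counter

# Hypothetical toy data: entries[row][col] mirrors the boolean entry field n.
genes = ["BRCA1", "BRCA1", "KRAS"]   # one gene label per row
obesity = ["lean", "lean", "obese"]  # one group label per column
entries = [
    [True, False, True],
    [True, True, False],
    [False, True, True],
]

# Step 1 (models group_cols_by + hl.agg.collect):
# build one list of entry values per (row, column group).
collected = []
for row in entries:
    per_group = {}
    for label, value in zip(obesity, row):
        per_group.setdefault(label, []).append(value)
    collected.append(per_group)

# Step 2 (models group_rows_by + hl.agg.explode + hl.agg.counter):
# for rows sharing a gene, count every element of the collected lists.
result = {}
for gene, per_group in zip(genes, collected):
    for label, values in per_group.items():
        result.setdefault(gene, {}).setdefault(label, Counter()).update(values)
```

The intermediate `collected` lists are the part that corresponds to the collect step: they exist only to be flattened and counted again in step 2.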

Explode isn’t necessarily expensive, but collect definitely can be if it’s materializing a large collection. In this case, you’re collecting a length num_rows array.

You’re right we don’t make grouping by both rows and columns easy. Here is the best way I found:

> tg = t.group_rows_by(t.gene).aggregate(count_true = hl.agg.count_where(t.n), count_all = hl.agg.count())
> tg.describe()
----------------------------------------
Global fields:
    None
----------------------------------------
Column fields:
    'col_idx': int32
    'obesity': str
----------------------------------------
Row fields:
    'gene': str
----------------------------------------
Entry fields:
    'count_true': int64
    'count_all': int64
----------------------------------------
Column key: ['col_idx']
Row key: ['gene']
----------------------------------------

> ta = tg.select_rows(counts = hl.agg.group_by(tg.obesity, hl.struct(true=hl.agg.sum(tg.count_true), all=hl.agg.sum(tg.count_all)))).rows()
> ta.describe()
----------------------------------------
Global fields:
    None
----------------------------------------
Row fields:
    'gene': str
    'counts': dict<str, struct {
        true: int64,
        all: int64
    }>
----------------------------------------
Key: ['gene']
----------------------------------------

> ta.show()
2021-08-30 09:45:59 Hail: INFO: Coerced sorted dataset
+---------+--------------------------------------------+
| gene    | counts                                     |
+---------+--------------------------------------------+
| str     | dict<str, struct{true: int64, all: int64}> |
+---------+--------------------------------------------+
| "BRCA1" | {"lean":(11,15),"obese":(10,15)}           |
| "KRAS"  | {"lean":(12,15),"obese":(13,15)}           |
+---------+--------------------------------------------+
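
As a sanity check on the logic, here is a plain-Python model (again, not Hail code) of the same two-stage aggregation: first count per gene and column, then sum those counts within each obesity group. The toy `genes`, `obesity`, and `entries` values are assumptions chosen for illustration only.

```python
from collections import defaultdict

# Hypothetical toy data: entries[row][col] mirrors the boolean entry field n.
genes = ["BRCA1", "BRCA1", "KRAS", "KRAS"]    # one gene label per row
obesity = ["lean", "lean", "obese", "obese"]  # one group label per column
entries = [
    [True, False, True, True],
    [True, True, False, False],
    [False, True, True, False],
    [True, False, True, True],
]

# Stage 1 (models group_rows_by(gene).aggregate(count_true, count_all)):
# one [count_true, count_all] pair per (gene, column).
stage1 = defaultdict(lambda: [0, 0])
for gene, row in zip(genes, entries):
    for col, value in enumerate(row):
        cell = stage1[(gene, col)]
        cell[0] += int(value)  # count_where(n)
        cell[1] += 1           # count()

# Stage 2 (models hl.agg.group_by(obesity, ...) with hl.agg.sum):
# add up the per-column counts within each column group.
counts = defaultdict(dict)
for (gene, col), (n_true, n_all) in stage1.items():
    prev_true, prev_all = counts[gene].get(obesity[col], (0, 0))
    counts[gene][obesity[col]] = (prev_true + n_true, prev_all + n_all)
```

Unlike the collect-based version, stage 1 here keeps only two integers per (gene, column) cell, so nothing proportional to the number of rows has to be held at once.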