Filter Table by max within group

Probably simple, but I have spent way too long on this. I have a table and would like to “de-duplicate” the rows, keeping only one row for each group in id2, chosen by the maximum value of val. I managed it with a gnarly group_by() -> hl.agg.max() -> join() chain, but there must be a neater way. Essentially, I would like to turn this table:

id1 id2 val more_columns
-----------
a x 0.1 ...
b x 0.3 ...
c x 0.4 ...
d y 0.2 ...
e y 0.9 ...
f y 0.5 ...

Into this:

id1 id2 val more_columns
-------------
c x 0.4 ...
e y 0.9 ...

I hope this makes sense. Thank you for your time!
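For reference, the group_by() -> hl.agg.max() -> join() approach described above can be modelled in plain Python. This is a sketch of the semantics only, not Hail code: the list of dicts is a hypothetical stand-in for the Hail table, holding just the example columns. It computes the maximum val per id2, then keeps rows whose val equals their group's maximum.

```python
# Hypothetical stand-in for the Hail table: the example rows from the question.
rows = [
    {"id1": "a", "id2": "x", "val": 0.1},
    {"id1": "b", "id2": "x", "val": 0.3},
    {"id1": "c", "id2": "x", "val": 0.4},
    {"id1": "d", "id2": "y", "val": 0.2},
    {"id1": "e", "id2": "y", "val": 0.9},
    {"id1": "f", "id2": "y", "val": 0.5},
]

# Step 1 (like group_by(id2).aggregate(max_val=hl.agg.max(val))):
# maximum val for each id2.
max_val = {}
for r in rows:
    g = r["id2"]
    if g not in max_val or r["val"] > max_val[g]:
        max_val[g] = r["val"]

# Step 2 (like joining the maxima back onto the rows and filtering):
# keep rows whose val equals their group's maximum.
deduped = [r for r in rows if r["val"] == max_val[r["id2"]]]

print([r["id1"] for r in deduped])  # ['c', 'e']
```

One property of this approach worth knowing: if two rows in a group tie for the maximum, both survive the filter, so it does not strictly guarantee one row per group.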

I think the group_by is exactly the right thing:

rows = mt.rows()
to_keep = rows.group_by(rows.id2).aggregate(top1=hl.agg.take(rows.id1, ordering=-rows.id2)[0]).key_by('top1')
mt = mt.semi_join_rows(to_keep)

Small correction for those who come looking in the future:

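import hail as hl

rows = mt.rows()
# For each id2, take the id1 of the row with the largest val, then key by it
to_keep = rows.group_by(rows.id2).aggregate(top1=hl.agg.take(rows.id1, 1, ordering=-rows.val)[0]).key_by('top1')
# Keep only rows of mt whose row key (assumed here to be id1) appears in to_keep
mt = mt.semi_join_rows(to_keep)
# or
# mt = mt.filter_rows(hl.is_defined(to_keep[mt.id1]))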
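To make the two steps of the corrected snippet concrete, here is a plain-Python model of the same pipeline on the example rows. It is a sketch under stated assumptions, not Hail itself: `rows` is a hypothetical stand-in for mt.rows(), and `take` is a hypothetical helper that mimics how hl.agg.take(expr, n, ordering=...) returns the first n values sorted in ascending order of the ordering key (so ordering by -val puts the largest val first).

```python
from collections import defaultdict

# Hypothetical stand-in for mt.rows(): the example rows from the question.
rows = [
    {"id1": "a", "id2": "x", "val": 0.1},
    {"id1": "b", "id2": "x", "val": 0.3},
    {"id1": "c", "id2": "x", "val": 0.4},
    {"id1": "d", "id2": "y", "val": 0.2},
    {"id1": "e", "id2": "y", "val": 0.9},
    {"id1": "f", "id2": "y", "val": 0.5},
]


def take(pairs, n):
    """Hypothetical model of hl.agg.take(expr, n, ordering=key): the first n
    values after sorting (value, key) pairs in ascending order of key."""
    return [value for value, _ in sorted(pairs, key=lambda p: p[1])][:n]


# group_by(rows.id2): collect (id1, -val) pairs for each id2 value.
groups = defaultdict(list)
for r in rows:
    groups[r["id2"]].append((r["id1"], -r["val"]))

# .aggregate(top1=hl.agg.take(rows.id1, 1, ordering=-rows.val)[0])
to_keep = {g: take(pairs, 1)[0] for g, pairs in groups.items()}

# .key_by('top1') then semi_join_rows / filter_rows(hl.is_defined(...)):
# keep only rows whose id1 is one of the selected keys.
keep_keys = set(to_keep.values())
result = [r for r in rows if r["id1"] in keep_keys]

print(to_keep)                     # {'x': 'c', 'y': 'e'}
print([r["id1"] for r in result])  # ['c', 'e']
```

Because the final filter matches on id1 alone, this yields one row per group only if id1 is unique; any other row sharing a selected id1 would also be kept.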

Oops, thanks! Was it just the 1 argument to take that I missed?

And ordering=-rows.val instead of ordering=-rows.id2, I think.
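Both fixes matter because of how take orders its output. The plain-Python sketch below uses a hypothetical `take` helper that models hl.agg.take(expr, n, ordering=...) as "first n values in ascending order of the ordering key" (this is a model of the documented behaviour, not Hail's code). Ordering by val keeps the smallest value, ordering by -val keeps the largest, and n controls the length of the returned list, which is why the snippet indexes the result with [0].

```python
def take(pairs, n):
    """Hypothetical model of hl.agg.take(expr, n, ordering=key): the first n
    values after sorting (value, key) pairs in ascending order of key."""
    return [value for value, _ in sorted(pairs, key=lambda p: p[1])][:n]


# Group 'x' from the example table as (id1, val) pairs.
group_x = [("a", 0.1), ("b", 0.3), ("c", 0.4)]

print(take(group_x, 1))                         # ordering=val  -> ['a'] (minimum)
print(take([(i, -v) for i, v in group_x], 1))   # ordering=-val -> ['c'] (maximum)
print(take([(i, -v) for i, v in group_x], 2))   # n=2 -> ['c', 'b'], hence [0]
```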

Thanks for the awesome information.


Thanks, my issue has been fixed.