Variant call rate filtering by group

Hello, I’m trying to filter my variants by call rate. I’d like to remove any variant whose call rate is below 90% in any of the 3 cohorts that were jointly called.
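
To make the goal concrete, here is a plain-Python sketch of the calculation I’m after, on made-up toy data. The cohort labels, variant names, and the helper call_rate_by_cohort are all hypothetical; none of them come from my dataset or from Hail. The key point is that the denominator is the number of samples in each cohort, not the total number of samples:

```python
from collections import defaultdict

# Hypothetical toy data: one cohort label per sample, and one genotype call
# per sample for each variant (None marks a missing call).
sample_cohort = ["A"] * 8 + ["B"] * 8 + ["C"] * 4
genotypes = {
    "var1": ["0/0"] * 20,
    "var2": ["0/0"] * 19 + [None],  # one missing call, in cohort C
}
THRESHOLD = 0.90


def call_rate_by_cohort(calls, cohorts):
    """Return {cohort: fraction of non-missing calls within that cohort}."""
    called = defaultdict(int)
    total = defaultdict(int)
    for call, cohort in zip(calls, cohorts):
        total[cohort] += 1
        if call is not None:
            called[cohort] += 1
    return {cohort: called[cohort] / total[cohort] for cohort in total}


# Per-variant, per-cohort call rates
rates = {
    variant: call_rate_by_cohort(calls, sample_cohort)
    for variant, calls in genotypes.items()
}

# Keep a variant only if every cohort meets the threshold
kept = [
    variant
    for variant, by_cohort in rates.items()
    if all(rate >= THRESHOLD for rate in by_cohort.values())
]
print(rates)
print(kept)
```

In this toy example var2 has an overall call rate of 95% (19 of 20) but only 75% in cohort C, so it should be removed even though it passes an overall 90% filter.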

Here is the code that I have started with, but I could use some help getting it to run properly:

First, I confirmed that my hand-calculated variant call rate matches the call rate computed by variant_qc (it does):

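# Call rate per variant: called genotypes divided by the total number of samples
mt = mt.annotate_rows(call_rate_trial = mt.variant_qc.n_called / mt.count_cols())
filter_condition = (mt.call_rate_trial < 0.90)
# Count variants (rows) rather than entries, so each variant is counted once
print(mt.aggregate_rows(hl.agg.counter(filter_condition)))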

Next, I tried to do the same thing grouped by cohort, without success. The first line doesn’t run:

mt = mt.annotate_rows(call_rate_trial_test = hl.group_by(mt.TRIAL, mt.variant_qc.n_called / mt.count_cols()))
filter_condition = (mt.call_rate_trial_test < 0.90)
print(mt.aggregate_rows(hl.agg.counter(filter_condition)))

Could anyone help me with how to do this grouped by cohort?