Querying variants by genotype counts for two cohorts

I want to write a query that returns the variants for which at least half of the genotypes in the control cohort are homozygous reference and at least half of the genotypes in the case cohort are heterozygous or homozygous alternate. The query below works, but I had to create a second MT (grouped by cohort) and then query that second MT.

Suppose I have the following MT:

locus    alleles     S1 (GT)  S2 (GT)  S3 (GT)  S4 (GT)  S5 (GT)
2:11944  ["T", "C"]  1/1      1/0      0/0      0/0      0/0
2:16374  ["A", "G"]  0/0      0/0      1/0      1/0      1/0
2:22422  ["T", "A"]  1/0      1/1      1/0      0/0      0/0

import hail as hl

controls = ['S4', 'S5']
mt = hl.read_matrix_table('test.mt')

# Label each sample by cohort
mt = mt.annotate_cols(
    cohort=hl.if_else(hl.literal(set(controls)).contains(mt.s), 'Control', 'Patient'))

# Collapse the samples into one column per cohort holding the hom-ref fraction
mt2 = mt.group_cols_by(mt.cohort).aggregate(hom_rate=hl.agg.fraction(mt.GT.is_hom_ref()))

# Keep rows where controls are mostly hom-ref and patients mostly not
mt2.filter_rows(hl.agg.all(
    ((mt2.hom_rate > 0.5) & (mt2.cohort == 'Control')) |
    ((mt2.hom_rate < 0.5) & (mt2.cohort == 'Patient'))
)).show()
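To make the effect of group_cols_by concrete, here is a plain-Python model (no Hail required) of what the query above computes on the toy MT. The dict layout and helper names (GENOTYPES, is_hom_ref, grouped_hom_rates, passes) are illustrative assumptions, not Hail API, and is_hom_ref is only a rough stand-in for GT.is_hom_ref() on unphased diploid calls.

```python
# Plain-Python model of the group_cols_by query on the toy MT (no Hail required).
# GENOTYPES, is_hom_ref, grouped_hom_rates and passes are illustrative names.
GENOTYPES = {
    ("2:11944", ("T", "C")): {"S1": "1/1", "S2": "1/0", "S3": "0/0", "S4": "0/0", "S5": "0/0"},
    ("2:16374", ("A", "G")): {"S1": "0/0", "S2": "0/0", "S3": "1/0", "S4": "1/0", "S5": "1/0"},
    ("2:22422", ("T", "A")): {"S1": "1/0", "S2": "1/1", "S3": "1/0", "S4": "0/0", "S5": "0/0"},
}
CONTROLS = {"S4", "S5"}


def is_hom_ref(gt: str) -> bool:
    """Rough stand-in for GT.is_hom_ref() on unphased diploid calls like '0/0'."""
    return all(allele == "0" for allele in gt.split("/"))


def grouped_hom_rates(genotypes):
    """Like group_cols_by(cohort).aggregate(hom_rate=...): one value per cohort per row."""
    result = {}
    for key, calls in genotypes.items():
        counts = {}  # cohort -> [hom-ref count, total count]
        for sample, gt in calls.items():
            cohort = "Control" if sample in CONTROLS else "Patient"
            tally = counts.setdefault(cohort, [0, 0])
            tally[0] += is_hom_ref(gt)
            tally[1] += 1
        result[key] = {cohort: hom / total for cohort, (hom, total) in counts.items()}
    return result


def passes(rates):
    """Like hl.agg.all(...) over the cohort columns of one row."""
    return all(
        (cohort == "Control" and rate > 0.5) or (cohort == "Patient" and rate < 0.5)
        for cohort, rate in rates.items()
    )


mt2 = grouped_hom_rates(GENOTYPES)
kept = {key: rates for key, rates in mt2.items() if passes(rates)}
for (locus, alleles), rates in kept.items():
    print(locus, alleles, rates)
```

Each surviving row maps cohort names to fractions and nothing else, which is why the per-sample genotypes are no longer available after grouping.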

Running this gives the following output:

locus    alleles     Control   Patient
2:11944  ["T","C"]   1.00e+00  3.33e-01
2:22422  ["T","A"]   1.00e+00  0.00e+00

I would also like to keep the original genotypes, not just the fractions. I assume I would have to take the loci remaining in mt2 and use them to filter mt. Is there a way to do this with just mt?
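The two-step approach can still work: filter mt2 first, then keep only the rows of mt whose key appears in it. In Hail this is a row semi-join, for example `mt.filter_rows(hl.is_defined(kept_rows[mt.row_key]))`, where `kept_rows` is a hypothetical name for the rows table of the filtered mt2. The plain-Python sketch below models that semi-join on the toy data; all names in it are illustrative.

```python
# Plain-Python model of a row semi-join (no Hail required); names are illustrative.
GENOTYPES = {
    ("2:11944", ("T", "C")): {"S1": "1/1", "S2": "1/0", "S3": "0/0", "S4": "0/0", "S5": "0/0"},
    ("2:16374", ("A", "G")): {"S1": "0/0", "S2": "0/0", "S3": "1/0", "S4": "1/0", "S5": "1/0"},
    ("2:22422", ("T", "A")): {"S1": "1/0", "S2": "1/1", "S3": "1/0", "S4": "0/0", "S5": "0/0"},
}
# Row keys that pass the cohort-level filter on the toy data.
KEPT_KEYS = {("2:11944", ("T", "C")), ("2:22422", ("T", "A"))}


def semi_join_rows(genotypes, keys):
    """Keep rows whose key is in `keys`; every sample's genotype is left untouched."""
    return {key: calls for key, calls in genotypes.items() if key in keys}


filtered = semi_join_rows(GENOTYPES, KEPT_KEYS)
for (locus, alleles), calls in filtered.items():
    print(locus, alleles, calls)
```

This recovers the genotypes but costs an extra join; computing the fractions as row fields directly on mt avoids it.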

Hey @syaffers!

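Instead of using group_cols_by, you can use a grouped aggregator (hl.agg.group_by) inside annotate_rows. This stores the per-cohort fractions as a row field while keeping every sample column, so the original genotypes stay in mt: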

# Flag each sample as control (True) or case (False)
mt = mt.annotate_cols(
    is_control=hl.literal(set(controls)).contains(mt.s)
)
# Per row: a dict mapping is_control -> fraction of hom-ref genotypes
mt = mt.annotate_rows(
    fraction_hom_ref_by_is_control=hl.agg.group_by(
        mt.is_control,
        hl.agg.fraction(mt.GT.is_hom_ref())
    )
)
mt = mt.filter_rows(
    (mt.fraction_hom_ref_by_is_control[True] > 0.5) &
    (mt.fraction_hom_ref_by_is_control[False] < 0.5)
)
mt.show()
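To see what the grouped aggregator produces, here is a plain-Python model (no Hail) of the fraction_hom_ref_by_is_control row field: for each row, a dict from is_control to the hom-ref fraction, which the filter then indexes with True and False. Apart from that field name, the names are illustrative assumptions, and is_hom_ref is only a rough stand-in for GT.is_hom_ref().

```python
# Plain-Python model of annotate_rows(hl.agg.group_by(...)) + filter_rows (no Hail required).
GENOTYPES = {
    ("2:11944", ("T", "C")): {"S1": "1/1", "S2": "1/0", "S3": "0/0", "S4": "0/0", "S5": "0/0"},
    ("2:16374", ("A", "G")): {"S1": "0/0", "S2": "0/0", "S3": "1/0", "S4": "1/0", "S5": "1/0"},
    ("2:22422", ("T", "A")): {"S1": "1/0", "S2": "1/1", "S3": "1/0", "S4": "0/0", "S5": "0/0"},
}
CONTROLS = {"S4", "S5"}


def is_hom_ref(gt: str) -> bool:
    """Rough stand-in for GT.is_hom_ref() on unphased diploid calls."""
    return all(allele == "0" for allele in gt.split("/"))


def fraction_hom_ref_by_is_control(calls):
    """Like hl.agg.group_by(is_control, hl.agg.fraction(GT.is_hom_ref()))."""
    groups = {}  # is_control -> [hom-ref count, total count]
    for sample, gt in calls.items():
        tally = groups.setdefault(sample in CONTROLS, [0, 0])
        tally[0] += is_hom_ref(gt)
        tally[1] += 1
    return {flag: hom / total for flag, (hom, total) in groups.items()}


# annotate_rows: each row keeps its genotypes and gains the dict-valued field
annotated = {
    key: {"calls": calls, "frac": fraction_hom_ref_by_is_control(calls)}
    for key, calls in GENOTYPES.items()
}
# filter_rows: index the dict with True (controls) and False (cases)
kept = {
    key: row for key, row in annotated.items()
    if row["frac"][True] > 0.5 and row["frac"][False] < 0.5
}
for (locus, alleles), row in kept.items():
    print(locus, alleles, row["frac"], row["calls"])
```

Indexing the dict fails when a key is absent (a KeyError here). In Hail, a cohort with no samples would likewise leave that key out of the grouped dict, so indexing with `[True]` or `[False]` would likely fail at runtime; a dictionary `.get` may be safer in that case.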

It might be easier to read this code if you used hl.agg.filter instead of hl.agg.group_by:

mt = mt.annotate_cols(
    is_control=hl.literal(set(controls)).contains(mt.s)
)
# Compute the hom-ref fraction separately over the control and case columns
mt = mt.annotate_rows(
    control_hom_rate=hl.agg.filter(
        mt.is_control,
        hl.agg.fraction(mt.GT.is_hom_ref())),
    case_hom_rate=hl.agg.filter(
        ~mt.is_control,
        hl.agg.fraction(mt.GT.is_hom_ref()))
)
mt = mt.filter_rows(
    (mt.control_hom_rate > 0.5) & (mt.case_hom_rate < 0.5)
)
mt.show()
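The question asks for "at least half", while both answers use strict comparisons (> 0.5 and < 0.5), so a variant sitting exactly at 0.5 is dropped; with only two controls that boundary is easy to hit. The plain-Python sketch below models the hl.agg.filter version and compares the two readings on a hypothetical boundary row (not from the question). All names are illustrative, and each cohort is assumed to be non-empty.

```python
# Plain-Python model of the hl.agg.filter version (no Hail required); names are illustrative.
CONTROLS = {"S4", "S5"}
# Hypothetical boundary row: one of the two controls is het, so the control hom-ref rate is 0.5.
BOUNDARY = {"S1": "1/0", "S2": "1/1", "S3": "1/0", "S4": "0/0", "S5": "1/0"}


def is_hom_ref(gt: str) -> bool:
    """Rough stand-in for GT.is_hom_ref() on unphased diploid calls."""
    return all(allele == "0" for allele in gt.split("/"))


def filtered_fraction(calls, keep):
    """Like hl.agg.filter(keep, hl.agg.fraction(GT.is_hom_ref())); assumes a non-empty group."""
    selected = [gt for sample, gt in calls.items() if keep(sample)]
    return sum(is_hom_ref(gt) for gt in selected) / len(selected)


def keep_row(calls, inclusive):
    control_hom_rate = filtered_fraction(calls, lambda s: s in CONTROLS)
    case_hom_rate = filtered_fraction(calls, lambda s: s not in CONTROLS)
    if inclusive:  # "at least half" read literally
        return control_hom_rate >= 0.5 and case_hom_rate <= 0.5
    return control_hom_rate > 0.5 and case_hom_rate < 0.5  # comparisons used in the answers


print("strict:", keep_row(BOUNDARY, inclusive=False))
print("inclusive:", keep_row(BOUNDARY, inclusive=True))
```

If the literal reading is what you want, the Hail filter would become `(mt.control_hom_rate >= 0.5) & (mt.case_hom_rate <= 0.5)`.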