Simple rare variant burden testing with Fisher's exact test

Hail has a general way of performing burden tests within a regression framework (see the Hail documentation on logistic regression burden tests).

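However, in rare variant analysis the carrier counts are often so low that some regression methods may not even converge; in that case Fisher's exact test can be used instead. I wrote a simple Python function that does just that. Given a filtered VDS, a variant annotation to group by, and a list of binary phenotypes, it performs a case-control burden test for each phenotype and returns a KeyTable with all comparisons, as well as the carrier and variant IDs. Within each group, an individual counts as a carrier if they have at least one non-reference call in any variant of the group, and the test compares carriers and non-carriers between cases and controls.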
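To make the counting concrete outside Hail, here is a small pure-Python sketch of the collapsing step that the two `aggregate_by_key` calls in `do_burden` perform. It assumes the data are already in plain Python structures (non-reference calls as `(sample, variant)` pairs, a variant-to-group map, and a sample-to-phenotype map); the function name `carrier_table` and these input shapes are hypothetical and not part of Hail.

```python
from collections import defaultdict


def carrier_table(non_ref_calls, variant_group, sample_pheno):
    """Collapse non-reference calls into per-group carrier counts (hypothetical helper).

    non_ref_calls: iterable of (sample, variant) pairs that have a non-ref call
    variant_group: dict variant -> group (e.g. gene)
    sample_pheno: dict sample -> True (case), False (control) or None (missing)
    Returns dict group -> (case_carriers, case_n, control_carriers, control_n).
    """
    # Samples with at least one non-ref call per group; using a set means a
    # sample with several variants in the same group is counted only once.
    carriers = defaultdict(set)
    for sample, variant in non_ref_calls:
        carriers[variant_group[variant]].add(sample)

    # Samples with a missing phenotype are neither cases nor controls.
    cases = {s for s, ph in sample_pheno.items() if ph is True}
    controls = {s for s, ph in sample_pheno.items() if ph is False}

    table = {}
    for group in set(variant_group.values()):
        group_carriers = carriers[group]
        table[group] = (
            len(group_carriers & cases),
            len(cases),
            len(group_carriers & controls),
            len(controls),
        )
    return table
```

The four values per group correspond to the `_case_carriers`, `_case_n`, `_control_carriers` and `_control_n` columns; the non-carrier counts passed to the test are then `case_n - case_carriers` and `control_n - control_carriers`.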

from functools import reduce

def do_burden(vds, pheno_list, group_by_anno, pheno_root="sa.pheno"):
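    """
    Fisher's exact burden test. For each element of group_by_anno, an individual
    is counted only once as a carrier if they have at least one non-reference
    call in any variant of that group.
    For e.g. a LoF burden test, first filter the VDS to LoF variants and call
    the function with the filtered VDS.
    :param vds: filtered VariantDataset
    :param pheno_list: list of binary phenotypes
    :param group_by_anno: variant annotation to group by, e.g. va.gene
    :param pheno_root: sample annotation root of the phenotypes
    :return: keytable with all comparison columns and carrier and variant IDs
    """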
    res = []
    for pheno in pheno_list:
        p = pheno.split(".")[-1]
        # One row per (group, sample): carrier = 1 if the sample has any non-ref call
        # in the group; vars = "sample[contig:start:ref:alt,...]" for carriers, else NA.
        g_kt = vds.aggregate_by_key('group=' + group_by_anno + ', sample=s, pheno=' + pheno_root + "." + pheno,
            'carrier = g.map(g => if (g.isCalledNonRef) 1 else 0).collect().max, '
            'vars = if (g.filter(g => g.isCalledNonRef).count() > 0) '
            'g.map(g => s).collect()[0] + "[" + g.filter(g => g.isCalledNonRef).map(g => v.contig + ":" + v.start + ":" + v.ref + ":" + v.alt).collect().mkString(",") + "]" '
            'else NA: String')
        # Aggregate per group over cases (pheno true) and controls (pheno false);
        # samples with a missing phenotype are dropped by both filters.
        case_carriers = g_kt.filter('pheno').aggregate_by_key('group = group',
            p + '_case_carriers = carrier.collect().sum, ' +
            p + '_case_n = carrier.count(), ' +
            p + '_case_vars = vars.filter(v => !isMissing(v)).collect().mkString(";")')
        control_carriers = g_kt.filter('!pheno').aggregate_by_key('group = group',
            p + '_control_carriers = carrier.collect().sum, ' +
            p + '_control_n = carrier.count(), ' +
            p + '_control_vars = vars.filter(v => !isMissing(v)).collect().mkString(";")')
        case_carriers = case_carriers.select(["group", p + "_case_carriers", p + "_case_n", p + "_case_vars"]).key_by(["group"])
        control_carriers = control_carriers.select(["group", p + "_control_carriers", p + "_control_n", p + "_control_vars"]).key_by(["group"])
        burdens = case_carriers.join(control_carriers, how="outer")
        # 2x2 table: case carriers, case non-carriers, control carriers, control non-carriers.
        b = burdens.annotate(p + '_assoc = fet(' + p + '_case_carriers.toInt, (' + p + '_case_n - ' + p + '_case_carriers).toInt, ' +
                             p + '_control_carriers.toInt, (' + p + '_control_n - ' + p + '_control_carriers).toInt)')
        res.append(b)

    # Combine the per-phenotype tables into one table keyed by group.
    if len(res) > 1:
        res = reduce(lambda x, y: x.join(y, how="outer"), res)
    else:
        res = res[0]

    return res.flatten()
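To spot-check individual rows of the result without Hail, the sketch below computes a two-sided Fisher's exact p-value from the same four counts passed to `fet` (case carriers, case non-carriers, control carriers, control non-carriers). It uses the common definition that sums the probabilities of all tables with the same margins that are no more likely than the observed one. The name `fisher_exact_two_sided` is hypothetical, it needs Python 3.8+ for `math.comb`, and Hail's `fet` may differ in small numerical details such as tie tolerance.

```python
from math import comb


def fisher_exact_two_sided(a, b, c, d):
    """Two-sided Fisher's exact p-value for the 2x2 table [[a, b], [c, d]].

    a = case carriers, b = case non-carriers,
    c = control carriers, d = control non-carriers.
    """
    row1 = a + b  # number of cases
    row2 = c + d  # number of controls
    col1 = a + c  # number of carriers
    n = row1 + row2
    # Range of the top-left cell that is compatible with the fixed margins.
    lo = max(0, col1 - row2)
    hi = min(row1, col1)
    # Hypergeometric weights C(row1, x) * C(row2, col1 - x). They share the
    # denominator C(n, col1), so comparing exact integers avoids float ties.
    weights = {x: comb(row1, x) * comb(row2, col1 - x) for x in range(lo, hi + 1)}
    observed = weights[a]
    total = comb(n, col1)
    tail = sum(w for w in weights.values() if w <= observed)
    return tail / total
```

For example, `fisher_exact_two_sided(8, 2, 1, 5)` returns 400/11440, about 0.035.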