De novo calls on hemizygous X variants

I’ve continued to work on this and ought to post another follow-up. Here is the version of the caller I’m currently using after more conversation with Dave Cutler and Mark Daly. It is no longer the same as hl.de_novo on the autosomes: I changed how the probability of a parental alt allele is calculated, added terms to account for the possibility that the kid doesn’t actually have a variant, and now use different de novo priors for mom vs. dad. It also relies on AC and AN values computed by my own function (I’ll post that below). In practice this still gives output quite similar to hl.de_novo, with the addition of calls in hemizygous regions.

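# NOTE: THIS REQUIRES A DATASET ANNOTATED WITH mt.AC AND mt.AN VALUES
# (they are read only when ignore_in_sample_allele_frequency is False, the default)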
@typecheck(mt=MatrixTable,
           pedigree=Pedigree,
           pop_frequency_prior=expr_float64,
           min_gq=int,
           min_p=numeric,
           max_parent_ab=numeric,
           min_child_ab=numeric,
           min_dp_ratio=numeric,
           ignore_in_sample_allele_frequency=bool)
def my_de_novo_v16(mt: MatrixTable,
                   pedigree: Pedigree,
                   pop_frequency_prior,
                   *,
                   min_gq: int = 20,
                   min_p: float = 0.05,
                   max_parent_ab: float = 0.05,
                   min_child_ab: float = 0.20,
                   min_dp_ratio: float = 0.10,
                   ignore_in_sample_allele_frequency: bool = False) -> Table:
    """Call putative de novo variants in complete trios, including hemizygous sites.

    Four models are used, selected per entry from the locus and the proband's
    ``is_female`` flag:

    * diploid -- autosome/PAR loci, and X non-PAR loci for female probands;
    * hemizygous X -- X non-PAR loci for non-female probands (mother only);
    * hemizygous Y -- Y non-PAR loci for non-female probands (father only);
    * mitochondrial -- mito loci, scored like a hemizygous site against the
      mother (heteroplasmy is ignored).

    Parameters
    ----------
    mt
        Biallelic dataset with entry fields ``GT``, ``AD``, ``DP``, ``GQ`` and
        ``PL``.  Unless ``ignore_in_sample_allele_frequency`` is true it must
        also carry row fields ``AC`` and ``AN``.
    pedigree
        Pedigree passed to ``hl.trio_matrix``; only complete trios are used.
    pop_frequency_prior
        Expression giving the population frequency prior of the alt allele;
        must lie in ``[0, 1]``.
    min_gq
        Proband entries with ``GQ`` below this value are rejected.
    min_p
        Minimum posterior probability of a de novo event for a call.
    max_parent_ab
        Maximum alt-read fraction tolerated in each parent that the selected
        model examines.
    min_child_ab
        Minimum alt-read fraction required in the proband.
    min_dp_ratio
        Minimum proband-to-parent depth ratio (see ``dp_ratio`` in the body
        for the per-model normalisation).
    ignore_in_sample_allele_frequency
        If true, ignore in-sample allele counts: the site frequency is taken
        from ``pop_frequency_prior`` (floored at an internal minimum) and the
        alt-allele count is fixed at 1.

    Returns
    -------
    Table
        One row per retained trio entry with fields ``prior``, ``proband``,
        ``father``, ``mother``, ``proband_entry``, ``father_entry``,
        ``mother_entry``, ``is_female``, ``p_de_novo`` and ``confidence``
        (``'HIGH'``, ``'MEDIUM'`` or ``'LOW'``).

    Raises
    ------
    ValueError
        If a required entry field is missing, or if ``AC``/``AN`` are missing
        while ``ignore_in_sample_allele_frequency`` is false.
    """
    MOM_DE_NOVO_PRIOR = 1 / 100000000
    DAD_DE_NOVO_PRIOR = 1 / 25000000
    MITO_DE_NOVO_PRIOR = 10 * MOM_DE_NOVO_PRIOR  # Rough estimate, could use review for mitochondrial applications

    MIN_POP_PRIOR = 100 / 30000000

    required_entry_fields = {'GT', 'AD', 'DP', 'GQ', 'PL'}
    missing_fields = required_entry_fields - set(mt.entry)
    if missing_fields:
        raise ValueError(f"'de_novo': expected 'MatrixTable' to have at least {required_entry_fields}, "
                         f"missing {missing_fields}")

    # Wrap the prior so that an out-of-range value is reported as an error
    # when the expression is evaluated instead of silently skewing the model.
    pop_frequency_prior = hl.case() \
        .when((pop_frequency_prior >= 0) & (pop_frequency_prior <= 1), pop_frequency_prior) \
        .or_error(hl.str("de_novo: expect 0 <= pop_frequency_prior <= 1, found " + hl.str(pop_frequency_prior)))

    if ignore_in_sample_allele_frequency:
        # this mode is used when families larger than a single trio are observed, in which
        # an allele might be de novo in a parent and transmitted to a child in the dataset.
        # The original model does not handle this case correctly, and so this experimental
        # mode can be used to treat each trio as if it were the only one in the dataset.
        mt = mt.annotate_rows(__prior=pop_frequency_prior,
                              __alt_alleles=hl.int64(1),
                              __site_freq=hl.max(pop_frequency_prior, MIN_POP_PRIOR))
    else:
        # Note that this now requires prior annotation with AC and AN values
        # (did this August 2022 to correct for hemizygous situations).
        # Fail early with an explicit message, mirroring the entry-field check above.
        required_row_fields = {'AC', 'AN'}
        missing_row_fields = required_row_fields - set(mt.row)
        if missing_row_fields:
            raise ValueError(f"'de_novo': expected 'MatrixTable' to have row fields "
                             f"{sorted(required_row_fields)} when "
                             f"'ignore_in_sample_allele_frequency' is False, "
                             f"missing {sorted(missing_row_fields)}")

        n_alt_alleles = mt.AC
        total_alleles = mt.AN
        # subtract 1 from the alt allele count in the frequency estimate to
        # correct for the observed genotype
        mt = mt.annotate_rows(__prior=pop_frequency_prior,
                              __alt_alleles=n_alt_alleles,
                              __site_freq=hl.max((n_alt_alleles - 1) / total_alleles,
                                                 pop_frequency_prior,
                                                 MIN_POP_PRIOR))

    mt = require_biallelic(mt, 'de_novo')

    tm = hl.trio_matrix(mt, pedigree, complete_trios=True)
    tm = tm.annotate_rows(__autosomal=tm.locus.in_autosome_or_par(),
                          __x_nonpar=tm.locus.in_x_nonpar(),
                          __y_nonpar=tm.locus.in_y_nonpar(),
                          __mito=tm.locus.in_mito())

    # Which model applies to each (variant, trio) entry.
    # NOTE(review): female probands on X non-PAR are routed through the diploid
    # model, which scores a carrier father with dad_pp[1] (het); a hemizygous
    # father carrying the allele would presumably be called hom-var -- confirm
    # this is intended.
    autosomal = tm.__autosomal | (tm.__x_nonpar & tm.is_female)
    hemi_x = tm.__x_nonpar & ~tm.is_female
    hemi_y = tm.__y_nonpar & ~tm.is_female
    hemi_mt = tm.__mito

    is_snp = hl.is_snp(tm.alleles[0], tm.alleles[1])
    n_alt_alleles = tm.__alt_alleles
    prior = tm.__site_freq

    # Updated for hemizygous child calls to not require a call in uninvolved parent
    has_candidate_gt_configuration = (
        (autosomal & tm.proband_entry.GT.is_het()
         & tm.father_entry.GT.is_hom_ref() & tm.mother_entry.GT.is_hom_ref())
        | (hemi_x & tm.proband_entry.GT.is_hom_var() & tm.mother_entry.GT.is_hom_ref())
        | (hemi_y & tm.proband_entry.GT.is_hom_var() & tm.father_entry.GT.is_hom_ref())
        | (hemi_mt & tm.proband_entry.GT.is_hom_var() & tm.mother_entry.GT.is_hom_ref()))

    # Missing struct used for every rejected entry; such entries are filtered out below.
    failure = hl.missing(hl.tstruct(p_de_novo=hl.tfloat64, confidence=hl.tstr))

    kid = tm.proband_entry
    dad = tm.father_entry
    mom = tm.mother_entry

    def genotype_posteriors(entry):
        """Convert an entry's phred-scaled PL array into probabilities summing to 1."""
        linear_pl = 10 ** (-entry.PL / 10)
        return hl.bind(lambda x: x / hl.sum(x), linear_pl)

    kid_pp = genotype_posteriors(kid)
    dad_pp = genotype_posteriors(dad)
    mom_pp = genotype_posteriors(mom)

    kid_ad_ratio = kid.AD[1] / hl.sum(kid.AD)

    # Try to get these all to an expected value of 0.5
    dp_ratio = (hl.case()
                .when(hemi_x, kid.DP / mom.DP)  # Because mom is diploid but kid is not
                .when(hemi_y, kid.DP / (2 * dad.DP))
                .when(hemi_mt, kid.DP / (2 * mom.DP))
                .when(tm.__x_nonpar & tm.is_female, (kid.DP / (mom.DP + dad.DP)) * (3 / 4))  # Because of hemi dad
                .default(kid.DP / (mom.DP + dad.DP)))

    def solve(p_de_novo, parent_tot_read_check, parent_alt_read_check):
        """Apply the quality filters, then grade a surviving call.

        Returns ``failure`` when any hard filter trips, a struct carrying
        ``p_de_novo`` and a confidence label when a tier matches, and missing
        when no tier matches.  Indels (``~is_snp``) and SNPs are graded with
        different tier rules.
        """
        return (
            hl.case()
            .when(kid.GQ < min_gq, failure)
            .when((dp_ratio < min_dp_ratio)
                  | (kid_ad_ratio < min_child_ab), failure)
            .when(~parent_tot_read_check, failure)
            .when(~parent_alt_read_check, failure)
            .when(p_de_novo < min_p, failure)
            .when(~is_snp, hl.case()
                  .when((p_de_novo > 0.99) & (kid_ad_ratio > 0.3) & (n_alt_alleles == 1),
                        hl.struct(p_de_novo=p_de_novo, confidence='HIGH'))
                  .when((p_de_novo > 0.5) & (kid_ad_ratio > 0.3) & (n_alt_alleles <= 5),
                        hl.struct(p_de_novo=p_de_novo, confidence='MEDIUM'))
                  .when(kid_ad_ratio > 0.2,
                        hl.struct(p_de_novo=p_de_novo, confidence='LOW'))
                  .or_missing())
            .default(hl.case()
                     .when(((p_de_novo > 0.99) & (kid_ad_ratio > 0.3) & (dp_ratio > 0.2))
                           | ((p_de_novo > 0.99) & (kid_ad_ratio > 0.3) & (n_alt_alleles == 1))
                           | ((p_de_novo > 0.5) & (kid_ad_ratio > 0.3) & (n_alt_alleles < 10) & (kid.DP > 10)),
                           hl.struct(p_de_novo=p_de_novo, confidence='HIGH'))
                     .when((p_de_novo > 0.5) & ((kid_ad_ratio > 0.3) | (n_alt_alleles == 1)),
                           hl.struct(p_de_novo=p_de_novo, confidence='MEDIUM'))
                     .when(kid_ad_ratio > 0.2,
                           hl.struct(p_de_novo=p_de_novo, confidence='LOW'))
                     .or_missing()))

    # Call autosomal or pseudoautosomal de novo variants
    def call_auto(kid_pp, dad_pp, mom_pp):
        """Diploid model: het proband, both parents evaluated.

        The posterior weighs "de novo from mom" and "de novo from dad" against
        "mom carries the allele", "dad carries the allele" and "nobody carries
        the allele (the proband call is wrong)".
        """
        p_data_given_dn_from_mom = mom_pp[0] * dad_pp[0] * kid_pp[1]
        p_dn_from_mom = ((1 - prior) ** 4) * MOM_DE_NOVO_PRIOR * (1 - DAD_DE_NOVO_PRIOR)

        p_data_given_dn_from_dad = mom_pp[0] * dad_pp[0] * kid_pp[1]
        p_dn_from_dad = ((1 - prior) ** 4) * (1 - MOM_DE_NOVO_PRIOR) * DAD_DE_NOVO_PRIOR

        p_data_given_mom_alt_allele = mom_pp[1] * dad_pp[0] * kid_pp[1]
        p_mom_alt_allele = prior * ((1 - prior) ** 3) * (1 - MOM_DE_NOVO_PRIOR) * (1 - DAD_DE_NOVO_PRIOR)

        p_data_given_dad_alt_allele = mom_pp[0] * dad_pp[1] * kid_pp[1]
        p_dad_alt_allele = prior * ((1 - prior) ** 3) * (1 - MOM_DE_NOVO_PRIOR) * (1 - DAD_DE_NOVO_PRIOR)

        p_data_given_no_alt_alleles = mom_pp[0] * dad_pp[0] * kid_pp[0]
        p_no_alt_alleles = ((1 - prior) ** 4) * (1 - MOM_DE_NOVO_PRIOR) * (1 - DAD_DE_NOVO_PRIOR)

        p_de_novo = ((p_data_given_dn_from_mom * p_dn_from_mom) + (p_data_given_dn_from_dad * p_dn_from_dad)) / (
            (p_data_given_dn_from_mom * p_dn_from_mom) + (p_data_given_dn_from_dad * p_dn_from_dad)
            + (p_data_given_mom_alt_allele * p_mom_alt_allele) + (p_data_given_dad_alt_allele * p_dad_alt_allele)
            + (p_data_given_no_alt_alleles * p_no_alt_alleles))

        # Both parents need reads, and neither may show more than max_parent_ab alt reads.
        parent_tot_read_check = (hl.sum(mom.AD) > 0) & (hl.sum(dad.AD) > 0)
        parent_alt_read_check = (((mom.AD[1] / hl.sum(mom.AD)) <= max_parent_ab)
                                 & ((dad.AD[1] / hl.sum(dad.AD)) <= max_parent_ab))

        return hl.bind(solve, p_de_novo, parent_tot_read_check, parent_alt_read_check)

    # Call hemizygous de novo variants on the X
    def call_hemi_x(kid_pp, mom_pp):
        """Hemizygous X model: hom-var proband scored against the mother only.

        The inherited alternative uses the mother's het probability
        (``mom_pp[1]``).
        """
        p_data_given_dn_from_mom = mom_pp[0] * kid_pp[2]
        p_dn_from_mom = ((1 - prior) ** 2) * MOM_DE_NOVO_PRIOR

        p_data_given_mom_alt_allele = mom_pp[1] * kid_pp[2]
        p_mom_alt_allele = prior * (1 - prior) * (1 - MOM_DE_NOVO_PRIOR)

        p_data_given_no_alt_alleles = mom_pp[0] * kid_pp[0]
        p_no_alt_alleles = ((1 - prior) ** 2) * (1 - MOM_DE_NOVO_PRIOR)

        p_de_novo = (p_data_given_dn_from_mom * p_dn_from_mom) / (
            (p_data_given_dn_from_mom * p_dn_from_mom)
            + (p_data_given_mom_alt_allele * p_mom_alt_allele)
            + (p_data_given_no_alt_alleles * p_no_alt_alleles))

        parent_tot_read_check = (hl.sum(mom.AD) > 0)
        parent_alt_read_check = (mom.AD[1] / hl.sum(mom.AD)) <= max_parent_ab

        return hl.bind(solve, p_de_novo, parent_tot_read_check, parent_alt_read_check)

    # Call hemizygous de novo variants on the Y
    def call_hemi_y(kid_pp, dad_pp):
        """Hemizygous Y model: hom-var proband scored against the father only.

        The inherited alternative uses the father's hom-var probability
        (``dad_pp[2]``).
        """
        p_data_given_dn_from_dad = dad_pp[0] * kid_pp[2]
        p_dn_from_dad = (1 - prior) * DAD_DE_NOVO_PRIOR

        p_data_given_dad_alt_allele = dad_pp[2] * kid_pp[2]
        p_dad_alt_allele = prior * (1 - DAD_DE_NOVO_PRIOR)

        p_data_given_no_alt_alleles = dad_pp[0] * kid_pp[0]
        p_no_alt_alleles = (1 - prior) * (1 - DAD_DE_NOVO_PRIOR)

        p_de_novo = (p_data_given_dn_from_dad * p_dn_from_dad) / (
            (p_data_given_dn_from_dad * p_dn_from_dad)
            + (p_data_given_dad_alt_allele * p_dad_alt_allele)
            + (p_data_given_no_alt_alleles * p_no_alt_alleles))

        parent_tot_read_check = (hl.sum(dad.AD) > 0)
        parent_alt_read_check = (dad.AD[1] / hl.sum(dad.AD)) <= max_parent_ab

        return hl.bind(solve, p_de_novo, parent_tot_read_check, parent_alt_read_check)

    # Call de novo variants on the mitochondrial chromosome
    # (treating it as a hemizygous situation and ignoring mito heteroplasmy)
    def call_hemi_mt(kid_pp, mom_pp):
        """Mitochondrial model: same shape as the Y model, scored against the
        mother with ``MITO_DE_NOVO_PRIOR``."""
        p_data_given_dn_from_mom = mom_pp[0] * kid_pp[2]
        p_dn_from_mom = (1 - prior) * MITO_DE_NOVO_PRIOR

        p_data_given_mom_alt_allele = mom_pp[2] * kid_pp[2]
        p_mom_alt_allele = prior * (1 - MITO_DE_NOVO_PRIOR)

        p_data_given_no_alt_alleles = mom_pp[0] * kid_pp[0]
        p_no_alt_alleles = (1 - prior) * (1 - MITO_DE_NOVO_PRIOR)

        p_de_novo = (p_data_given_dn_from_mom * p_dn_from_mom) / (
            (p_data_given_dn_from_mom * p_dn_from_mom)
            + (p_data_given_mom_alt_allele * p_mom_alt_allele)
            + (p_data_given_no_alt_alleles * p_no_alt_alleles))

        parent_tot_read_check = (hl.sum(mom.AD) > 0)
        parent_alt_read_check = (mom.AD[1] / hl.sum(mom.AD)) <= max_parent_ab

        return hl.bind(solve, p_de_novo, parent_tot_read_check, parent_alt_read_check)

    # Main routine: reject non-candidate genotype configurations first, then
    # dispatch to the model matching the locus / proband sex.
    de_novo_call = (
        hl.case()
        .when(~has_candidate_gt_configuration, failure)
        .when(autosomal, hl.bind(call_auto, kid_pp, dad_pp, mom_pp))
        .when(hemi_x, hl.bind(call_hemi_x, kid_pp, mom_pp))
        .when(hemi_y, hl.bind(call_hemi_y, kid_pp, dad_pp))
        .when(hemi_mt, hl.bind(call_hemi_mt, kid_pp, mom_pp))
        .or_missing())

    # Keep only entries that produced a call, then flatten to a table.
    tm = tm.annotate_entries(__call=de_novo_call)
    tm = tm.filter_entries(hl.is_defined(tm.__call))
    entries = tm.entries()

    return (entries.select('__site_freq',
                           'proband',
                           'father',
                           'mother',
                           'proband_entry',
                           'father_entry',
                           'mother_entry',
                           'is_female',
                           **entries.__call)
            .rename({'__site_freq': 'prior'}))