Scans running out of memory

Hi Hail team,

I’m running this code on the full exome:

def calculate_exp_per_base(
    context_ht: hl.Table,
    groupings: List[str] = [
        "context",
        "ref",
        "alt",
        "cpg",
        "methylation_level",
        "mu_snp",
        "transcript",
        "exome_coverage",
    ],
) -> hl.Table:
    """
    Returns table with expected variant counts annotated per base. 

    Expected variant count is the mutation rate per SNP adjusted by location in the genome/CpG status (plateau model) and coverage (coverage model).

    .. note::
        Expects:
        - context_ht is annotated with all of the fields in `groupings`, and the field names match exactly.
            That means the HT should have context, ref, alt, CpG status, methylation level, mutation rate (`mu_snp`), 
            transcript, and coverage (`exome_coverage`) if using the default value for `groupings`.
        - context_ht contains coverage and plateau models in its global annotations (`coverage_model`, `plateau_models`).

    :param hl.Table context_ht: Context Table.
    :param List[str] groupings: List of Table fields used to group Table to adjust mutation rate. 
        Table must be annotated with these fields. Default fields are context, ref, alt, cpg, methylation level, mu_snp, transcript, and exome_coverage.
    :return: Table annotated per base with cumulative expected variant counts (`cumulative_exp`) and adjusted mutation rates (`mu`).
    :rtype: hl.Table
    """
    logger.info(f"Annotating HT with groupings: {groupings}...")
    context_ht = context_ht.annotate(variant=(context_ht.row.select(*groupings)))

    logger.info("Getting cumulative aggregated mutation rate per variant...")
    context_ht = context_ht.annotate(
        mu_agg=hl.scan.group_by(
            context_ht.variant,
            hl.struct(
                # Use scan sum here because this annotation stores the cumulative mutation rate for
                # each context/ref/alt/methylation level
                mu_agg=hl.scan.sum(context_ht.variant.mu_snp),
                # Need _prev_nonnull to get the correct cpg and exome coverage for each scanned variant
                # Without this, the scan pulls the values from the current line, NOT the previous line
                cpg=hl.scan._prev_nonnull(context_ht.variant.cpg),
                coverage_correction=get_coverage_correction_expr(
                    hl.scan._prev_nonnull(context_ht.variant.exome_coverage),
                    context_ht.coverage_model,
                ),
            ),
        )
    )

    logger.info("Adjusting mutation rate using plateau and coverage models...")
    model = get_plateau_model(
        context_ht.locus, context_ht.cpg, context_ht.globals, include_cpg=True
    )
    context_ht = context_ht.annotate(
        mu=context_ht.mu_agg.map_values(lambda x: (x.mu_agg) * x.coverage_correction),
        all_exp=context_ht.mu_agg.map_values(
            lambda x: (x.mu_agg * model[x.cpg][1] + model[x.cpg][0])
            * x.coverage_correction
        ),
    )

    logger.info("Aggregating proportion of expected variants per site and returning...")
    context_ht = context_ht.annotate(
        transcript_exp_keys=context_ht.all_exp.keys().filter(
            lambda x: x.transcript == context_ht.transcript
        )
    )
    context_ht = context_ht.annotate(
        transcript_exp=hl.map(
            lambda x: context_ht.all_exp.get(x), context_ht.transcript_exp_keys
        )
    )
    return context_ht.annotate(cumulative_exp=hl.sum(context_ht.transcript_exp)).select(
        "cumulative_exp", "mu"
    )

and am running into this error:

# There is insufficient memory for the Java Runtime Environment to continue.
# Native memory allocation (mmap) failed to map 2729967616 bytes for committing reserved memory.
# An error report file with more information is saved as:
# /tmp/08573378253f479a862fbc3119ba0d6d/hs_err_pid13698.log

Unfortunately, the error report didn't get copied to a separate bucket before the cluster shut down. I was using 120 non-preemptible highmem-8 workers (boot disk size: 100 GB). Do you have any suggestions for how to fix this? I'm trying to get this run done before the conference next week.
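In case it's partly a matter of giving the JVM more headroom, this is the kind of thing I could try - a minimal sketch only, not something I've verified: the property names are standard Spark settings, but the values below are guesses, and I believe on a Dataproc cluster started with hailctl they would need to go in as cluster properties at creation time rather than into hl.init:

import hail as hl

# Sketch only: spark.driver.memory / spark.executor.memory / spark.executor.memoryOverhead
# are standard Spark properties, but the values are guesses for highmem-8 machines, and on
# a cluster started with hailctl they would have to be supplied at cluster creation time
# instead of here.
hl.init(
    spark_conf={
        "spark.driver.memory": "40g",
        "spark.executor.memory": "20g",
        # The failure was a native mmap allocation, not a Java heap OutOfMemoryError,
        # so extra off-heap headroom may matter more than a bigger heap.
        "spark.executor.memoryOverhead": "8g",
    }
)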

For some context (heh), this is a huge job: it's a scan down 9B rows, though the table is sparse, and "only" ~180M rows have any data in them. The group_by key isn't literally every variant but the context/ref/alt/etc. grouping, which can have roughly 100 keys per transcript, and there are obviously a lot of transcripts, so the scan will end up holding a very large number of keys. My current suspicion is that by the bottom of the table the scan is carrying a lot of state, most of which is never needed again: once we're past a given transcript, it would be ideal to "clear" that transcript's entries from the scan, but obviously that's not possible here. Any ideas on how to optimize this?
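For reference, here's a rough sketch of one way to quantify that key growth (untested on the full table; groupings is just the default list from the function above, and building the full set is itself expensive, so this is only a one-off diagnostic, not part of the pipeline):

import hail as hl

# Default groupings from calculate_exp_per_base; the scan's dictionary gets one entry
# per distinct combination of these fields seen so far.
groupings = [
    "context", "ref", "alt", "cpg", "methylation_level",
    "mu_snp", "transcript", "exome_coverage",
]

# How many distinct keys does the scan have to hold by the last row, and how many
# transcripts contribute them? Note this builds the full set in a single aggregation,
# so it's only feasible as a one-off check.
key_stats = context_ht.aggregate(
    hl.struct(
        n_scan_keys=hl.len(hl.agg.collect_as_set(context_ht.row.select(*groupings))),
        n_transcripts=hl.len(hl.agg.collect_as_set(context_ht.transcript)),
    )
)
print(key_stats)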