Exception in thread "Spark Context Cleaner" java.lang.OutOfMemoryError: Java heap space

Hey guys,

I have been running the gnomAD "gene models" pipeline, which combines GTF files, HGNC data, and canonical transcripts lists:

hailctl dataproc submit data-prep \
    --pyfiles ./data/data_utils \
    ./data/prepare_gene_models.py \
    --gencode 29 /path/to/gencode.v29.gtf.bgz $CANONICAL_TRANSCRIPTS_GRCH38_PATH \
    --gencode 19 /path/to/gencode.v19.gtf.bgz $CANONICAL_TRANSCRIPTS_GRCH37_PATH \
    --hgnc /path/to/hgnc.tsv \
    --mane-select-transcripts /path/to/mane_summary.tsv.gz \
    --output /path/to/genes.ht

My job cannot finish; I keep getting the error below, and I am not sure how to fix it:
"
Exception in thread "RemoteBlock-temp-file-clean-thread" java.lang.OutOfMemoryError: Java heap space
Exception in thread "Spark Context Cleaner" java.lang.OutOfMemoryError: Java heap space
Traceback (most recent call last):
  File "/home/mamta/Test/gnomad-browser/data/prepare_gene_models.py", line 348, in <module>
    main()
  File "/home/mamta/Test/gnomad-browser/data/prepare_gene_models.py", line 213, in main
    gencode_genes = load_gencode_gene_models(gtf_path, min_partitions=args.min_partitions)
  File "/home/mamta/Test/gnomad-browser/data/prepare_gene_models.py", line 187, in load_gencode_gene_models
    genes = genes.cache()
  File "/home/mamta/.local/lib/python3.6/site-packages/hail/table.py", line 1742, in cache
    return self.persist('MEMORY_ONLY')
  File "", line 2, in persist
  File "/home/mamta/.local/lib/python3.6/site-packages/hail/typecheck/check.py", line 585, in wrapper
    return original_func(*args, **kwargs)
  File "/home/mamta/.local/lib/python3.6/site-packages/hail/table.py", line 1780, in persist
    return Env.backend().persist_table(self, storage_level)
  File "/home/mamta/.local/lib/python3.6/site-packages/hail/backend/backend.py", line 227, in persist_table
    return Table._from_java(self._to_java_ir(t._tir).pyPersist(storage_level))
  File "/home/mamta/.local/lib/python3.6/site-packages/py4j/java_gateway.py", line 1257, in __call__
    answer, self.gateway_client, self.target_id, self.name)
  File "/home/mamta/.local/lib/python3.6/site-packages/hail/utils/java.py", line 211, in deco
    'Error summary: %s' % (deepest, full, hail.version, deepest)) from None
hail.utils.java.FatalError: SparkException: Job aborted due to stage failure: Task 9 in stage 4.0 failed 1 times, most recent failure: Lost task 9.0 in stage 4.0 (TID 137, localhost, executor driver): ExecutorLostFailure (executor driver exited caused by one of the running tasks) Reason: Executor heartbeat timed out after 137617 ms
Driver stacktrace:

Java stack trace:
org.apache.spark.SparkException: Job aborted due to stage failure: Task 9 in stage 4.0 failed 1 times, most recent failure: Lost task 9.0 in stage 4.0 (TID 137, localhost, executor driver): ExecutorLostFailure (executor driver exited caused by one of the running tasks) Reason: Executor heartbeat timed out after 137617 ms
Driver stacktrace:
at org.apache.spark.scheduler.DAGScheduler.org$apache$spark$scheduler$DAGScheduler$$failJobAndIndependentStages(DAGScheduler.scala:1889)
at org.apache.spark.scheduler.DAGScheduler$$anonfun$abortStage$1.apply(DAGScheduler.scala:1877)
at org.apache.spark.scheduler.DAGScheduler$$anonfun$abortStage$1.apply(DAGScheduler.scala:1876)
at scala.collection.mutable.ResizableArray$class.foreach(ResizableArray.scala:59)
at scala.collection.mutable.ArrayBuffer.foreach(ArrayBuffer.scala:48)
at org.apache.spark.scheduler.DAGScheduler.abortStage(DAGScheduler.scala:1876)
at org.apache.spark.scheduler.DAGScheduler$$anonfun$handleTaskSetFailed$1.apply(DAGScheduler.scala:926)
at org.apache.spark.scheduler.DAGScheduler$$anonfun$handleTaskSetFailed$1.apply(DAGScheduler.scala:926)
at scala.Option.foreach(Option.scala:257)
at org.apache.spark.scheduler.DAGScheduler.handleTaskSetFailed(DAGScheduler.scala:926)
at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.doOnReceive(DAGScheduler.scala:2110)
at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.onReceive(DAGScheduler.scala:2059)
at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.onReceive(DAGScheduler.scala:2048)
at org.apache.spark.util.EventLoop$$anon$1.run(EventLoop.scala:49)
at org.apache.spark.scheduler.DAGScheduler.runJob(DAGScheduler.scala:737)
at org.apache.spark.SparkContext.runJob(SparkContext.scala:2061)
at org.apache.spark.SparkContext.runJob(SparkContext.scala:2082)
at org.apache.spark.SparkContext.runJob(SparkContext.scala:2114)
at is.hail.sparkextras.ContextRDD.crunJobWithIndex(ContextRDD.scala:228)
at is.hail.rvd.RVD$.getKeyInfo(RVD.scala:1226)
at is.hail.rvd.RVD$.makeCoercer(RVD.scala:1301)
at is.hail.rvd.RVD$.coerce(RVD.scala:1256)
at is.hail.rvd.RVD$.coerce(RVD.scala:1240)
at is.hail.expr.ir.TableKeyByAndAggregate.execute(TableIR.scala:1733)
at is.hail.expr.ir.TableLeftJoinRightDistinct.execute(TableIR.scala:1062)
at is.hail.expr.ir.TableMapRows.execute(TableIR.scala:1088)
at is.hail.expr.ir.TableLeftJoinRightDistinct.execute(TableIR.scala:1061)
at is.hail.expr.ir.TableMapRows.execute(TableIR.scala:1088)
at is.hail.expr.ir.Interpret$.apply(Interpret.scala:23)
at is.hail.expr.ir.TableIR$$anonfun$persist$1.apply(TableIR.scala:53)
at is.hail.expr.ir.TableIR$$anonfun$persist$1.apply(TableIR.scala:52)
at is.hail.expr.ir.ExecuteContext$$anonfun$scoped$1.apply(ExecuteContext.scala:15)
at is.hail.expr.ir.ExecuteContext$$anonfun$scoped$1.apply(ExecuteContext.scala:13)
at is.hail.utils.package$.using(package.scala:604)
at is.hail.annotations.Region$.scoped(Region.scala:18)
at is.hail.expr.ir.ExecuteContext$.scoped(ExecuteContext.scala:13)
at is.hail.expr.ir.ExecuteContext$.scoped(ExecuteContext.scala:10)
at is.hail.expr.ir.TableIR.persist(TableIR.scala:52)
at is.hail.expr.ir.TableIR.pyPersist(TableIR.scala:72)
at sun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)
at sun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)
at sun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)
at java.lang.reflect.Method.invoke(Method.java:498)
at py4j.reflection.MethodInvoker.invoke(MethodInvoker.java:244)
at py4j.reflection.ReflectionEngine.invoke(ReflectionEngine.java:357)
at py4j.Gateway.invoke(Gateway.java:282)
at py4j.commands.AbstractCommand.invokeMethod(AbstractCommand.java:132)
at py4j.commands.CallCommand.execute(CallCommand.java:79)
at py4j.GatewayConnection.run(GatewayConnection.java:238)
at java.lang.Thread.run(Thread.java:748)

Hail version: 0.2.37-7952b436bd70

This is a rather old version. Could you try updating to the latest Hail and running again?

I updated to 0.2.45, but I still get the same error:

"[Stage 0:=====================================================> (30 + 2) / 32]2020-06-30 13:14:57 Hail: INFO: wrote table with 2619444 rows in 32 partitions to /tmp/s6rmbQZkajmoEMpZG8BGZ0
[Stage 2:========================================> (23 + 9) / 32]2020-06-30 13:15:01 Hail: INFO: Ordering unsorted dataset with network shuffle
2020-06-30 13:15:03 Hail: INFO: Ordering unsorted dataset with network shuffle
[Stage 5:> (0 + 32) / 32]Traceback (most recent call last):
  File "/home/mamta/Test/gnomad-browser/data/prepare_gene_models.py", line 271, in <module>
    main()
  File "/home/mamta/Test/gnomad-browser/data/prepare_gene_models.py", line 159, in main
    gencode_genes = load_gencode_gene_models(gtf_path, min_partitions=args.min_partitions)
  File "/home/mamta/Test/gnomad-browser/data/prepare_gene_models.py", line 137, in load_gencode_gene_models
    genes = genes.cache()
  File "/home/mamta/miniconda3/lib/python3.7/site-packages/hail/table.py", line 1798, in cache
    return self.persist('MEMORY_ONLY')
  File "", line 2, in persist
  File "/home/mamta/miniconda3/lib/python3.7/site-packages/hail/typecheck/check.py", line 614, in wrapper
    return original_func(*args, **kwargs)
  File "/home/mamta/miniconda3/lib/python3.7/site-packages/hail/table.py", line 1836, in persist
    return Env.backend().persist_table(self, storage_level)
  File "/home/mamta/miniconda3/lib/python3.7/site-packages/hail/backend/spark_backend.py", line 315, in persist_table
    return Table._from_java(self._jbackend.pyPersistTable(storage_level, self._to_java_table_ir(t._tir)))
  File "/home/mamta/miniconda3/lib/python3.7/site-packages/py4j/java_gateway.py", line 1257, in __call__
    answer, self.gateway_client, self.target_id, self.name)
  File "/home/mamta/miniconda3/lib/python3.7/site-packages/hail/backend/spark_backend.py", line 41, in deco
    'Error summary: %s' % (deepest, full, hail.version, deepest)) from None
hail.utils.java.FatalError: SparkException: Job aborted due to stage failure: Task 31 in stage 5.0 failed 1 times, most recent failure: Lost task 31.0 in stage 5.0 (TID 191, localhost, executor driver): ExecutorLostFailure (executor driver exited caused by one of the running tasks) Reason: Executor heartbeat timed out after 136076 ms
Driver stacktrace:

Java stack trace:
org.apache.spark.SparkException: Job aborted due to stage failure: Task 31 in stage 5.0 failed 1 times, most recent failure: Lost task 31.0 in stage 5.0 (TID 191, localhost, executor driver): ExecutorLostFailure (executor driver exited caused by one of the running tasks) Reason: Executor heartbeat timed out after 136076 ms
Driver stacktrace:
at org.apache.spark.scheduler.DAGScheduler.org$apache$spark$scheduler$DAGScheduler$$failJobAndIndependentStages(DAGScheduler.scala:1889)
at org.apache.spark.scheduler.DAGScheduler$$anonfun$abortStage$1.apply(DAGScheduler.scala:1877)
at org.apache.spark.scheduler.DAGScheduler$$anonfun$abortStage$1.apply(DAGScheduler.scala:1876)
at scala.collection.mutable.ResizableArray$class.foreach(ResizableArray.scala:59)
at scala.collection.mutable.ArrayBuffer.foreach(ArrayBuffer.scala:48)
at org.apache.spark.scheduler.DAGScheduler.abortStage(DAGScheduler.scala:1876)
at org.apache.spark.scheduler.DAGScheduler$$anonfun$handleTaskSetFailed$1.apply(DAGScheduler.scala:926)
at org.apache.spark.scheduler.DAGScheduler$$anonfun$handleTaskSetFailed$1.apply(DAGScheduler.scala:926)
at scala.Option.foreach(Option.scala:257)
at org.apache.spark.scheduler.DAGScheduler.handleTaskSetFailed(DAGScheduler.scala:926)
at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.doOnReceive(DAGScheduler.scala:2110)
at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.onReceive(DAGScheduler.scala:2059)
at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.onReceive(DAGScheduler.scala:2048)
at org.apache.spark.util.EventLoop$$anon$1.run(EventLoop.scala:49)
at org.apache.spark.scheduler.DAGScheduler.runJob(DAGScheduler.scala:737)
at org.apache.spark.SparkContext.runJob(SparkContext.scala:2061)
at org.apache.spark.SparkContext.runJob(SparkContext.scala:2082)
at org.apache.spark.SparkContext.runJob(SparkContext.scala:2114)
at is.hail.sparkextras.ContextRDD.crunJobWithIndex(ContextRDD.scala:228)
at is.hail.rvd.RVD$.getKeyInfo(RVD.scala:1272)
at is.hail.rvd.RVD$.makeCoercer(RVD.scala:1347)
at is.hail.rvd.RVD$.coerce(RVD.scala:1303)
at is.hail.rvd.RVD$.coerce(RVD.scala:1287)
at is.hail.expr.ir.TableKeyByAndAggregate.execute(TableIR.scala:2069)
at is.hail.expr.ir.TableLeftJoinRightDistinct.execute(TableIR.scala:1406)
at is.hail.expr.ir.TableMapRows.execute(TableIR.scala:1432)
at is.hail.expr.ir.TableLeftJoinRightDistinct.execute(TableIR.scala:1405)
at is.hail.expr.ir.TableMapRows.execute(TableIR.scala:1432)
at is.hail.expr.ir.Interpret$.apply(Interpret.scala:23)
at is.hail.backend.spark.SparkBackend$$anonfun$pyPersistTable$1.apply(SparkBackend.scala:402)
at is.hail.backend.spark.SparkBackend$$anonfun$pyPersistTable$1.apply(SparkBackend.scala:401)
at is.hail.expr.ir.ExecuteContext$$anonfun$scoped$1.apply(ExecuteContext.scala:20)
at is.hail.expr.ir.ExecuteContext$$anonfun$scoped$1.apply(ExecuteContext.scala:18)
at is.hail.utils.package$.using(package.scala:601)
at is.hail.annotations.Region$.scoped(Region.scala:18)
at is.hail.expr.ir.ExecuteContext$.scoped(ExecuteContext.scala:18)
at is.hail.backend.spark.SparkBackend.withExecuteContext(SparkBackend.scala:229)
at is.hail.backend.spark.SparkBackend.pyPersistTable(SparkBackend.scala:401)
at sun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)
at sun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)
at sun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)
at java.lang.reflect.Method.invoke(Method.java:498)
at py4j.reflection.MethodInvoker.invoke(MethodInvoker.java:244)
at py4j.reflection.ReflectionEngine.invoke(ReflectionEngine.java:357)
at py4j.Gateway.invoke(Gateway.java:282)
at py4j.commands.AbstractCommand.invokeMethod(AbstractCommand.java:132)
at py4j.commands.CallCommand.execute(CallCommand.java:79)
at py4j.GatewayConnection.run(GatewayConnection.java:238)
at java.lang.Thread.run(Thread.java:748)

Hail version: 0.2.45-a45a43f21e83
Error summary: SparkException: Job aborted due to stage failure: Task 31 in stage 5.0 failed 1 times, most recent failure: Lost task 31.0 in stage 5.0 (TID 191, localhost, executor driver): ExecutorLostFailure (executor driver exited caused by one of the running tasks) Reason: Executor heartbeat timed out after 136076 ms
Driver stacktrace:
ERROR:root:Exception while sending command.
Traceback (most recent call last):
  File "/home/mamta/miniconda3/lib/python3.7/site-packages/py4j/java_gateway.py", line 1159, in send_command
    raise Py4JNetworkError("Answer from Java side is empty")
py4j.protocol.Py4JNetworkError: Answer from Java side is empty

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/home/mamta/miniconda3/lib/python3.7/site-packages/py4j/java_gateway.py", line 985, in send_command
    response = connection.send_command(command)
  File "/home/mamta/miniconda3/lib/python3.7/site-packages/py4j/java_gateway.py", line 1164, in send_command
    "Error while receiving", e, proto.ERROR_ON_RECEIVE)
py4j.protocol.Py4JNetworkError: Error while receiving
ERROR:py4j.java_gateway:An error occurred while trying to connect to the Java server (127.0.0.1:46285)
Traceback (most recent call last):
  File "/home/mamta/miniconda3/lib/python3.7/site-packages/py4j/java_gateway.py", line 929, in _get_connection
    connection = self.deque.pop()
IndexError: pop from an empty deque

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/home/mamta/miniconda3/lib/python3.7/site-packages/py4j/java_gateway.py", line 1067, in start
    self.socket.connect((self.address, self.port))
ConnectionRefusedError: [Errno 111] Connection refused
(the "An error occurred while trying to connect to the Java server (127.0.0.1:46285)" block above repeats three more times)

What's that script?

The script is from the gnomad-browser repo:

import argparse

import hail as hl

from data_utils.regions import merge_overlapping_regions


def xpos(contig_str, position):
    contig_number = (
        hl.case()
        .when(contig_str == "X", 23)
        .when(contig_str == "Y", 24)
        .when(contig_str[0] == "M", 25)
        .default(hl.int(contig_str))
    )

    return hl.int64(contig_number) * 1_000_000_000 + position


###############################################
# Exons                                       #
###############################################


def get_exons(gencode):
    """
    Filter Gencode table to exons and format fields.
    """
    exons = gencode.filter(hl.set(["exon", "CDS", "UTR"]).contains(gencode.feature))
    exons = exons.select(
        feature_type=exons.feature,
        transcript_id=exons.transcript_id.split("\\.")[0],
        gene_id=exons.gene_id.split("\\.")[0],
        chrom=exons.interval.start.seqname[3:],
        strand=exons.strand,
        start=exons.interval.start.position,
        stop=exons.interval.end.position,
    )

    return exons


###############################################
# Genes                                       #
###############################################


def get_genes(gencode):
    """
    Filter Gencode table to genes and format fields.
    """
    genes = gencode.filter(gencode.feature == "gene")
    genes = genes.select(
        gene_id=genes.gene_id.split("\\.")[0],
        gene_version=genes.gene_id.split("\\.")[1],
        gene_symbol=genes.gene_name,
        chrom=genes.interval.start.seqname[3:],
        strand=genes.strand,
        start=genes.interval.start.position,
        stop=genes.interval.end.position,
    )

    genes = genes.annotate(xstart=xpos(genes.chrom, genes.start), xstop=xpos(genes.chrom, genes.stop))

    genes = genes.key_by(genes.gene_id).drop("interval")

    return genes


def collect_gene_exons(gene_exons):
    # There are 3 feature types in the exons collection: "CDS", "UTR", and "exon".
    # There are "exon" regions that cover the "CDS" and "UTR" regions and also
    # some (non-coding) transcripts that contain only "exon" regions.
    # This filters the "exon" regions to only those that are in non-coding transcripts.
    #
    # This makes the UI for selecting visible regions easier, since it can filter
    # on "CDS" or "UTR" feature type without having to also filter out the "exon" regions
    # that duplicate the "CDS" and "UTR" regions.

    non_coding_transcript_exons = hl.bind(
        lambda coding_transcripts: gene_exons.filter(lambda exon: ~coding_transcripts.contains(exon.transcript_id)),
        hl.set(
            gene_exons.filter(lambda exon: (exon.feature_type == "CDS") | (exon.feature_type == "UTR")).map(
                lambda exon: exon.transcript_id
            )
        ),
    )

    exons = (
        merge_overlapping_regions(gene_exons.filter(lambda exon: exon.feature_type == "CDS"))
        .extend(merge_overlapping_regions(gene_exons.filter(lambda exon: exon.feature_type == "UTR")))
        .extend(merge_overlapping_regions(non_coding_transcript_exons))
    )

    exons = exons.map(
        lambda exon: exon.select(
            "feature_type", "start", "stop", xstart=xpos(exon.chrom, exon.start), xstop=xpos(exon.chrom, exon.stop)
        )
    )

    return exons


###############################################
# Transcripts                                 #
###############################################


def get_transcripts(gencode):
    """
    Filter Gencode table to transcripts and format fields.
    """
    transcripts = gencode.filter(gencode.feature == "transcript")
    transcripts = transcripts.select(
        transcript_id=transcripts.transcript_id.split("\\.")[0],
        transcript_version=transcripts.transcript_id.split("\\.")[1],
        gene_id=transcripts.gene_id.split("\\.")[0],
        chrom=transcripts.interval.start.seqname[3:],
        strand=transcripts.strand,
        start=transcripts.interval.start.position,
        stop=transcripts.interval.end.position,
    )

    transcripts = transcripts.annotate(
        xstart=xpos(transcripts.chrom, transcripts.start), xstop=xpos(transcripts.chrom, transcripts.stop)
    )

    transcripts = transcripts.key_by(transcripts.transcript_id).drop("interval")

    return transcripts


def collect_transcript_exons(transcript_exons):
    # There are 3 feature types in the exons collection: "CDS", "UTR", and "exon".
    # There are "exon" regions that cover the "CDS" and "UTR" regions and also
    # some (non-coding) transcripts that contain only "exon" regions.
    # This filters the "exon" regions to only those that are in non-coding transcripts.
    #
    # This makes the UI for selecting visible regions easier, since it can filter
    # on "CDS" or "UTR" feature type without having to also filter out the "exon" regions
    # that duplicate the "CDS" and "UTR" regions.

    is_coding = transcript_exons.any(lambda exon: (exon.feature_type == "CDS") | (exon.feature_type == "UTR"))

    exons = hl.cond(is_coding, transcript_exons.filter(lambda exon: exon.feature_type != "exon"), transcript_exons)

    exons = exons.map(
        lambda exon: exon.select(
            "feature_type", "start", "stop", xstart=xpos(exon.chrom, exon.start), xstop=xpos(exon.chrom, exon.stop)
        )
    )

    return exons


###############################################
# Main                                        #
###############################################


def load_gencode_gene_models(gtf_path, min_partitions=32):
    gencode = hl.experimental.import_gtf(gtf_path, min_partitions=min_partitions)

    # Extract genes and transcripts
    genes = get_genes(gencode)
    transcripts = get_transcripts(gencode)

    # Annotate genes/transcripts with their exons
    exons = get_exons(gencode)
    exons = exons.cache()

    gene_exons = exons.group_by(exons.gene_id).aggregate(exons=hl.agg.collect(exons.row_value))
    genes = genes.annotate(exons=collect_gene_exons(gene_exons[genes.gene_id].exons))

    transcript_exons = exons.group_by(exons.transcript_id).aggregate(exons=hl.agg.collect(exons.row_value))

    transcripts = transcripts.annotate(
        exons=collect_transcript_exons(transcript_exons[transcripts.transcript_id].exons)
    )

    # Annotate genes with their transcripts
    gene_transcripts = transcripts.key_by()
    gene_transcripts = gene_transcripts.group_by(gene_transcripts.gene_id).aggregate(
        transcripts=hl.agg.collect(gene_transcripts.row_value)
    )
    genes = genes.annotate(**gene_transcripts[genes.gene_id])
    genes = genes.cache()

    return genes


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--gencode",
        action="append",
        default=[],
        metavar=("version", "gtf_path", "canonical_transcripts_path"),
        nargs=3,
        required=True,
    )
    parser.add_argument("--hgnc")
    parser.add_argument("--mane-select-transcripts")
    parser.add_argument("--min-partitions", type=int, default=32)
    parser.add_argument("--output", required=True)
    args = parser.parse_args()

    genes = None

    all_gencode_versions = [gencode_version for gencode_version, _, _ in args.gencode]

    for gencode_version, gtf_path, canonical_transcripts_path in args.gencode:
        gencode_genes = load_gencode_gene_models(gtf_path, min_partitions=args.min_partitions)

        # Canonical transcripts file is a TSV with two columns: gene ID and transcript ID and no header row
        canonical_transcripts = hl.import_table(
            canonical_transcripts_path, key="gene_id", min_partitions=args.min_partitions
        )
        gencode_genes = gencode_genes.annotate(
            canonical_transcript_id=canonical_transcripts[gencode_genes.gene_id].transcript_id
        )

        gencode_genes = gencode_genes.select(**{f"v{gencode_version}": gencode_genes.row_value})

        if not genes:
            genes = gencode_genes
        else:
            genes = genes.join(gencode_genes, "outer")

    genes = genes.select(gencode=genes.row_value)

    hgnc = hl.import_table(args.hgnc, missing="")

    hgnc = hgnc.select(
        hgnc_id=hgnc["HGNC ID"],
        symbol=hgnc["Approved symbol"],
        name=hgnc["Approved name"],
        previous_symbols=hgnc["Previous symbols"],
        alias_symbols=hgnc["Alias symbols"],
        omim_id=hgnc["OMIM ID(supplied by OMIM)"],
        gene_id=hl.or_else(hgnc["Ensembl gene ID"], hgnc["Ensembl ID(supplied by Ensembl)"]),
    )
    hgnc = hgnc.filter(hl.is_defined(hgnc.gene_id)).key_by("gene_id")
    hgnc = hgnc.annotate(
        previous_symbols=hl.cond(
            hgnc.previous_symbols == "",
            hl.empty_array(hl.tstr),
            hgnc.previous_symbols.split(",").map(lambda s: s.strip()),
        ),
        alias_symbols=hl.cond(
            hgnc.alias_symbols == "", hl.empty_array(hl.tstr), hgnc.alias_symbols.split(",").map(lambda s: s.strip())
        ),
    )

    genes = genes.annotate(**hgnc[genes.gene_id])
    genes = genes.annotate(symbol_source=hl.cond(hl.is_defined(genes.symbol), "hgnc", hl.null(hl.tstr)))

    # If an HGNC gene symbol was not present, use the symbol from Gencode
    for gencode_version in all_gencode_versions:
        genes = genes.annotate(
            symbol=hl.or_else(genes.symbol, genes.gencode[f"v{gencode_version}"].gene_symbol),
            symbol_source=hl.cond(
                hl.is_missing(genes.symbol) & hl.is_defined(genes.gencode[f"v{gencode_version}"].gene_symbol),
                f"gencode (v{gencode_version})",
                genes.symbol_source,
            ),
        )

    # Collect all fields that can be used to search by gene name
    genes = genes.annotate(
        symbol_upper_case=genes.symbol.upper(),
        search_terms=hl.empty_array(hl.tstr)
        .append(genes.symbol)
        .extend(genes.previous_symbols)
        .extend(genes.alias_symbols),
    )
    for gencode_version in all_gencode_versions:
        genes = genes.annotate(
            search_terms=hl.rbind(
                genes.gencode[f"v{gencode_version}"].gene_symbol,
                lambda symbol_in_gencode: hl.cond(
                    hl.is_defined(symbol_in_gencode), genes.search_terms.append(symbol_in_gencode), genes.search_terms
                ),
            )
        )

    genes = genes.annotate(search_terms=hl.set(genes.search_terms.map(lambda s: s.upper())))

    if args.mane_select_transcripts:
        mane_select_transcripts = hl.import_table(args.mane_select_transcripts, force=True)
        mane_select_transcripts = mane_select_transcripts.select(
            gene_id=mane_select_transcripts.Ensembl_Gene.split("\\.")[0],
            matched_gene_version=mane_select_transcripts.Ensembl_Gene.split("\\.")[1],
            ensembl_id=mane_select_transcripts.Ensembl_nuc.split("\\.")[0],
            ensembl_version=mane_select_transcripts.Ensembl_nuc.split("\\.")[1],
            refseq_id=mane_select_transcripts.RefSeq_nuc.split("\\.")[0],
            refseq_version=mane_select_transcripts.RefSeq_nuc.split("\\.")[1],
        )
        mane_select_transcripts = mane_select_transcripts.key_by("gene_id")

        # For GRCh38 (Gencode >= 20) transcripts, use the MANE Select transcripts to annotate transcripts
        # with their matching RefSeq transcript.
        ensembl_to_refseq_map = {}
        for transcript in mane_select_transcripts.collect():
            ensembl_to_refseq_map[transcript.ensembl_id] = {
                transcript.ensembl_version: hl.Struct(
                    refseq_id=transcript.refseq_id, refseq_version=transcript.refseq_version
                )
            }

        ensembl_to_refseq_map = hl.literal(ensembl_to_refseq_map)

        for gencode_version in ["19", "29"]:
            if int(gencode_version) >= 20:
                transcript_annotation = lambda transcript: transcript.annotate(
                    **ensembl_to_refseq_map.get(
                        transcript.transcript_id,
                        hl.empty_dict(hl.tstr, hl.tstruct(refseq_id=hl.tstr, refseq_version=hl.tstr)),
                    ).get(
                        transcript.transcript_version,
                        hl.struct(refseq_id=hl.null(hl.tstr), refseq_version=hl.null(hl.tstr)),
                    )
                )
            else:
                transcript_annotation = lambda transcript: transcript.annotate(
                    refseq_id=hl.null(hl.tstr), refseq_version=hl.null(hl.tstr)
                )

            genes = genes.annotate(
                gencode=genes.gencode.annotate(
                    **{
                        f"v{gencode_version}": genes.gencode[f"v{gencode_version}"].annotate(
                            transcripts=genes.gencode[f"v{gencode_version}"].transcripts.map(transcript_annotation)
                        )
                    }
                )
            )

        # Annotate genes with their MANE Select transcript
        genes = genes.annotate(mane_select_transcript=mane_select_transcripts[genes.gene_id])

    genes.describe()

    genes.write(args.output, overwrite=True)


if __name__ == "__main__":
    main()

Thanks, will look through this!

Thanks, Tim!

Actually, this suggestion from the Hail docs solved it:

"How do I increase the memory or RAM available to the JVM when I start Hail through Python?

If running Hail locally, memory can be increased by setting the PYSPARK_SUBMIT_ARGS environment variable. For example,

PYSPARK_SUBMIT_ARGS="--driver-memory 4g pyspark-shell" python script.py"
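
For anyone who finds this later, here is roughly what that fix looks like in practice. This is only a minimal sketch of the idea: it assumes the script is launched directly with python on a local machine (not submitted to Dataproc), and 4g is just the value from the docs example; this pipeline may need more.

import os

# PYSPARK_SUBMIT_ARGS must be in the environment before Hail/Spark launches
# its JVM, so set it at the very top of the script (or export it in the
# shell before running `python prepare_gene_models.py ...`, as in the docs).
os.environ["PYSPARK_SUBMIT_ARGS"] = "--driver-memory 4g pyspark-shell"

import hail as hl  # import Hail only after the variable is set

hl.init()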

You're not running Hail in local mode; you're running on Dataproc, so I wouldn't expect this environment variable to affect the memory behavior.

I was running it in local mode. Sorry, I didn't mention that before.

Oh, okay, great!