How do I avoid OutOfMemory error when I try to join many Matrix Tables or VCFs using the entries method?

Hello,

I wanted to join 800 VCF files into a single Table and store it. This is the code:

table = None
for input_vcf in tqdm(input_vcfs):
    # load vcf
    mt = hl.methods.import_vcf(input_vcf, contig_recoding=recode, force_bgz=True, reference_genome=None)

    # clean information
    mt = mt.select_entries(mt.GT, mt.DP, mt.GQ)
    mt = mt.select_rows()
    mt = mt.select_cols()

    valid = hl.is_valid_locus(mt.locus.contig, mt.locus.position, reference_genome='GRCh37')

    # mt_wrong = mt.filter_rows(~valid)
    mt_correct = mt.filter_rows(valid)

    mt_correct = mt_correct.annotate_entries(GT = mt_correct.GT[0] + mt_correct.GT[1])
    mt_correct = mt_correct.annotate_entries(GT = hl.coalesce(mt_correct.GT, -1))
    mt_correct = mt_correct.annotate_entries(DP = hl.coalesce(mt_correct.DP, 0))
    mt_correct = mt_correct.annotate_entries(GQ = hl.coalesce(mt_correct.GQ, 0))

    # store correct variants in the joined Table
    if mt_correct.rows().count() > 0:
        if table is None:
            table = mt_correct.entries()
        else:
            table = table.join(mt_correct.entries())

table.write('db/Table/table.ht', overwrite=True)

but after a while I got this error message:

java.lang.OutOfMemoryError: GC overhead limit exceeded
at scala.reflect.ManifestFactory$$anon$2.newArray(Manifest.scala:177)
at scala.reflect.ManifestFactory$$anon$2.newArray(Manifest.scala:176)
at scala.collection.mutable.WrappedArrayBuilder.mkArray(WrappedArrayBuilder.scala:46)
at scala.collection.mutable.WrappedArrayBuilder.resize(WrappedArrayBuilder.scala:53)
at scala.collection.mutable.WrappedArrayBuilder.sizeHint(WrappedArrayBuilder.scala:58)
at com.twitter.chill.WrappedArraySerializer.read(WrappedArraySerializer.scala:38)
at com.twitter.chill.WrappedArraySerializer.read(WrappedArraySerializer.scala:23)
at com.esotericsoftware.kryo.Kryo.readObject(Kryo.java:731)
at com.esotericsoftware.kryo.serializers.DefaultArraySerializers$ObjectArraySerializer.read(DefaultArraySerializers.java:391)
at com.esotericsoftware.kryo.serializers.DefaultArraySerializers$ObjectArraySerializer.read(DefaultArraySerializers.java:302)
at com.esotericsoftware.kryo.Kryo.readObject(Kryo.java:731)
at com.esotericsoftware.kryo.serializers.ObjectField.read(ObjectField.java:125)
at com.esotericsoftware.kryo.serializers.FieldSerializer.read(FieldSerializer.java:543)
at com.esotericsoftware.kryo.Kryo.readClassAndObject(Kryo.java:813)
at org.apache.spark.serializer.KryoDeserializationStream.readObject(KryoSerializer.scala:278)
at org.apache.spark.serializer.DeserializationStream.readKey(Serializer.scala:156)
at org.apache.spark.serializer.DeserializationStream$$anon$2.getNext(Serializer.scala:188)
at org.apache.spark.serializer.DeserializationStream$$anon$2.getNext(Serializer.scala:185)
at org.apache.spark.util.NextIterator.hasNext(NextIterator.scala:73)
at scala.collection.Iterator$$anon$12.hasNext(Iterator.scala:439)
at scala.collection.Iterator$$anon$11.hasNext(Iterator.scala:409)
at org.apache.spark.util.CompletionIterator.hasNext(CompletionIterator.scala:31)
at org.apache.spark.InterruptibleIterator.hasNext(InterruptibleIterator.scala:37)
at org.apache.spark.util.collection.ExternalSorter.insertAll(ExternalSorter.scala:199)
at org.apache.spark.shuffle.BlockStoreShuffleReader.read(BlockStoreShuffleReader.scala:102)
at org.apache.spark.rdd.ShuffledRDD.compute(ShuffledRDD.scala:105)
at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:324)
at org.apache.spark.rdd.RDD.iterator(RDD.scala:288)
at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:52)
at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:324)
at org.apache.spark.rdd.RDD.iterator(RDD.scala:288)
at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:52)

Driver stacktrace:
at org.apache.spark.scheduler.DAGScheduler.org$apache$spark$scheduler$DAGScheduler$$failJobAndIndependentStages(DAGScheduler.scala:1889)
at org.apache.spark.scheduler.DAGScheduler$$anonfun$abortStage$1.apply(DAGScheduler.scala:1877)
at org.apache.spark.scheduler.DAGScheduler$$anonfun$abortStage$1.apply(DAGScheduler.scala:1876)
at scala.collection.mutable.ResizableArray$class.foreach(ResizableArray.scala:59)
at scala.collection.mutable.ArrayBuffer.foreach(ArrayBuffer.scala:48)
at org.apache.spark.scheduler.DAGScheduler.abortStage(DAGScheduler.scala:1876)
at org.apache.spark.scheduler.DAGScheduler$$anonfun$handleTaskSetFailed$1.apply(DAGScheduler.scala:926)
at org.apache.spark.scheduler.DAGScheduler$$anonfun$handleTaskSetFailed$1.apply(DAGScheduler.scala:926)
at scala.Option.foreach(Option.scala:257)
at org.apache.spark.scheduler.DAGScheduler.handleTaskSetFailed(DAGScheduler.scala:926)
at is.hail.utils.package$.using(package.scala:609)
at is.hail.annotations.Region$.scoped(Region.scala:18)
at is.hail.expr.ir.ExecuteContext$.scoped(ExecuteContext.scala:18)
at is.hail.backend.spark.SparkBackend.withExecuteContext(SparkBackend.scala:230)
at is.hail.backend.spark.SparkBackend.execute(SparkBackend.scala:320)
at is.hail.backend.spark.SparkBackend.executeJSON(SparkBackend.scala:349)
at sun.reflect.GeneratedMethodAccessor66.invoke(Unknown Source)
at sun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)
at java.lang.reflect.Method.invoke(Method.java:498)
at py4j.reflection.MethodInvoker.invoke(MethodInvoker.java:244)
at py4j.reflection.ReflectionEngine.invoke(ReflectionEngine.java:357)
at py4j.Gateway.invoke(Gateway.java:282)
at py4j.commands.AbstractCommand.invokeMethod(AbstractCommand.java:132)
at py4j.commands.CallCommand.execute(CallCommand.java:79)
at py4j.GatewayConnection.run(GatewayConnection.java:238)
at java.lang.Thread.run(Thread.java:748)

The entries method is very slow and uses a lot of memory. You should redesign your algorithm so that it does not use the entries method at all.

VCFs do not typically contain combinable data. If, however, you really think this is the right thing to do, try something like this (I have not tested this code):

import hail as hl

def tree_outer_join(mts):
    """Recursively full-outer-join a list of MatrixTables, collecting each
    dataset's row, column, and entry fields into arrays."""
    n = len(mts)
    if n == 0:
        raise ValueError('cannot join empty list of MTs')
    if n == 1:
        mt = mts[0]
        mt = mt.select_rows(the_rows = [mt.row])
        mt = mt.select_cols(the_cols = [mt.col])
        mt = mt.select_entries(the_entries = [mt.entry])
        return mt
    if n == 2:
        mt = hl.experimental.full_outer_join_mt(mts[0], mts[1])
        mt = mt.select_rows(the_rows = [mt.left_row, mt.right_row])
        mt = mt.select_cols(the_cols = [mt.left_col, mt.right_col])
        mt = mt.select_entries(the_entries = [mt.left_entry, mt.right_entry])
        return mt
    assert n >= 3
    # split the list in half, join each half recursively, then join the two halves
    left = mts[:n//2]
    right = mts[n//2:]
    mt = hl.experimental.full_outer_join_mt(tree_outer_join(left),
                                            tree_outer_join(right))
    mt = mt.select_rows(the_rows = mt.left_row.the_rows.extend(mt.right_row.the_rows))
    mt = mt.select_cols(the_cols = mt.left_col.the_cols.extend(mt.right_col.the_cols))
    mt = mt.select_entries(the_entries = mt.left_entry.the_entries.extend(mt.right_entry.the_entries))
    return mt

mt = tree_outer_join([hl.import_vcf(fname) for fname in input_vcfs])

Now you have a MatrixTable with a row field, a column field, and an entry field, each of which is an array containing the corresponding fields from the constituent MatrixTables (a short sketch of how to work with these arrays follows the example output below). For example:

In [20]: mt = tree_outer_join([hl.balding_nichols_model(1, 3, 3) for _ in range(10)])
    ...: mt.show()
...
+---------------+------------+---------------------------------------------------------------+---------------------------------------------------------------+---------------------------------------------------------------+
| locus         | alleles    | 0.the_entries                                                 | 1.the_entries                                                 | 2.the_entries                                                 |
+---------------+------------+---------------------------------------------------------------+---------------------------------------------------------------+---------------------------------------------------------------+
| locus<GRCh37> | array<str> | array<struct{GT: call}>                                       | array<struct{GT: call}>                                       | array<struct{GT: call}>                                       |
+---------------+------------+---------------------------------------------------------------+---------------------------------------------------------------+---------------------------------------------------------------+
| 1:1           | ["A","C"]  | [(0/0),(0/0),(0/0),(0/1),(0/1),(1/1),(0/0),(0/1),(0/1),(1/1)] | [(0/1),(0/1),(0/0),(0/0),(1/1),(0/1),(0/0),(1/1),(0/1),(1/1)] | [(0/0),(1/1),(0/1),(1/1),(0/1),(1/1),(0/1),(0/0),(0/1),(1/1)] |
| 1:2           | ["A","C"]  | [(1/1),(0/1),(0/0),(0/1),(1/1),(1/1),(0/1),(1/1),(1/1),(1/1)] | [(1/1),(0/1),(0/1),(1/1),(1/1),(1/1),(1/1),(1/1),(1/1),(1/1)] | [(0/1),(1/1),(0/1),(0/1),(1/1),(1/1),(1/1),(0/1),(1/1),(1/1)] |
| 1:3           | ["A","C"]  | [(1/1),(0/1),(0/1),(1/1),(0/1),(0/1),(0/1),(1/1),(0/1),(0/1)] | [(0/1),(0/0),(0/1),(0/1),(0/0),(0/0),(0/0),(0/1),(0/1),(1/1)] | [(0/0),(0/1),(0/1),(0/1),(1/1),(0/1),(1/1),(0/1),(1/1),(0/1)] |
+---------------+------------+---------------------------------------------------------------+---------------------------------------------------------------+---------------------------------------------------------------+
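
If you need individual fields back out of those arrays, you can index into the_entries or reduce over it. Here is a hedged, untested sketch (the field names gt_from_dataset_2 and any_gt are just illustrative):

# untested sketch: recover single entry fields from the per-dataset arrays
# produced by tree_outer_join above

# GT from the third input dataset (index 2), for every (variant, sample) cell
mt = mt.annotate_entries(gt_from_dataset_2 = mt.the_entries[2].GT)

# or take the first non-missing GT across all input datasets
mt = mt.annotate_entries(
    any_gt = mt.the_entries.find(lambda e: hl.is_defined(e.GT)).GT)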

I suspect this will still fail (800 VCFs is a lot of VCFs). I recommend combining smaller batches of VCFs, writing each batch out as a MatrixTable, then combining those MatrixTables. For example:

import hailtop.utils as hu

mts = [hl.import_vcf(fname) for fname in input_vcfs]

def write_it(i, mt_group):
    # join one batch of MatrixTables and write it out
    fname = f'gs://some/path/{i}'
    tree_outer_join(mt_group).write(fname)
    return fname

# combine in batches of 50, then join the written batches
saved_filenames = [
    write_it(i, mt_group)
    for i, mt_group in enumerate(hu.grouped(50, mts))
]
mt = tree_outer_join([hl.read_matrix_table(fname) for fname in saved_filenames])
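
Once the final MatrixTable is built, write it once at the end instead of accumulating a Table with repeated joins. A one-line sketch (the output path is just a placeholder):

mt.write('gs://some/path/combined.mt', overwrite=True)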