Hi Guys,
Do you have a recommendation for how to best export SNP dosages from a MatrixTable to a numpy/pandas object?
The goal is to filter subjects and create a numpy matrix of 0/1/2 alt allele counts for use as input to a TensorFlow model. The overall data size isn’t astronomical and can fit into RAM.
Subject filtering isn’t an issue, but I’ve been toying with the “n_alt_alleles” function, combined with collect() / take(), and haven’t identified a good solution.
I would appreciate any thoughts or advice you might have!
Hi @tpoterba, I have the same question, except my data IS astronomical and does not fit into RAM, at all.
My current method is to use n_alt_alleles, convert the MatrixTable into a Table with columns indexed by sample ID, and then collect each sample one by one, convert it into a TensorFlow Example object, and write it to a TFRecord. This works, but it isn’t the most efficient approach (I’m looking into parallelising it). I’m wondering if you have any suggestions for a better pipeline for this.
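For reference, here’s a simplified sketch of that pipeline (the paths, output filename, and missing-genotype handling below are placeholders rather than my exact code):

import hail as hl
import tensorflow as tf

mt = hl.read_matrix_table('...')
# one integer dosage per entry, then pull the entries into a row-indexed table
mt = mt.select_entries(n=mt.GT.n_alt_alleles())
ht = mt.localize_entries('gts', 'cols')
# gts becomes an array of dosages per variant; missing calls are set to 0
# just to keep the sketch simple
ht = ht.select(gts=ht.gts.map(lambda e: hl.coalesce(e.n, 0)))

samples = mt.s.collect()
with tf.io.TFRecordWriter('dosages.tfrecord') as writer:  # placeholder filename
    for i, s in enumerate(samples):
        # one pass over the whole table per sample; this is the slow part
        dosages = ht.gts[i].collect()
        example = tf.train.Example(features=tf.train.Features(feature={
            'sample_id': tf.train.Feature(
                bytes_list=tf.train.BytesList(value=[s.encode()])),
            'dosages': tf.train.Feature(
                int64_list=tf.train.Int64List(value=dosages)),
        }))
        writer.write(example.SerializeToString())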
I think there will be better strategies than the one you’re employing now, but nothing immediately jumps to mind. I’ll ruminate on this a bit and respond this afternoon or tomorrow.
Hi @tpoterba, wondering if you’ve had a chance to give this more thought? I still haven’t come up with a better strategy myself and even parallelised, this is taking forever.
Ah, thanks for the bump! I’ve thought about this a bit, but don’t have any great solution. Part of why this is hard is that you’re doing a transpose – Hail MatrixTables are stored and processed row-major, but you want to export columns to TF individually. I think you can improve the speed a bit by collecting samples in groups (as many as can fit in memory at once), something like the following:
ht = mt.select_entries(n = mt.GT.n_alt_alleles()).localize_entries('gts')
# destructure array of structs
ht = ht.select(gts = ht.gts.map(lambda x: x['n']))
# checkpoint so there's minimal work to do on each iteration
ht = ht.checkpoint('...')
# now loop through in groups of `GROUP_SIZE` samples
for group_idx in range(N_SAMPLES // GROUP_SIZE):
    start = group_idx * GROUP_SIZE
    end = start + GROUP_SIZE
    local_data = ht.gts[start:end].collect()
    # ... process this and send to TF ...
Thanks for the response! I’m not sure I quite follow what you’re suggesting. Would this effectively transpose the whole thing first (by making it into an array of structs)? Also, what does the checkpoint do? Unfortunately, the batches I can fit in memory seem to be only about 3 samples, which doesn’t help much.
I had originally considered getting the data out of the Hail format and feeding it into a TensorFlow pipeline with a custom generator, but I worried (and still do) that it would be extremely slow due to Hail’s lazy evaluation. It seemed better to convert all the data ahead of time into something TensorFlow can read quickly, like a TFRecord, so that at least my model training would be fast, but I underestimated how incredibly slow TensorFlow’s IO class is. Another option is HDF5, but I’m not very familiar with the format and I don’t know whether it’s any easier to convert into.
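For what it’s worth, the HDF5 variant I have in mind would be roughly the following (a sketch using h5py, continuing from the per-sample collect in my earlier post; the filename and dtype are just guesses at something workable):

import h5py
import numpy as np

# samples, mt and ht as in the per-sample sketch above
n_samples = len(samples)
n_variants = mt.count_rows()

with h5py.File('dosages.h5', 'w') as f:  # placeholder filename
    dset = f.create_dataset('dosages', shape=(n_samples, n_variants), dtype='int8')
    for i in range(n_samples):
        # still one pass over the data per sample, but the rows can be
        # read back cheaply during training
        dset[i, :] = np.asarray(ht.gts[i].collect(), dtype='int8')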
Generally I’m really stuck on this so any ideas are extremely appreciated.
import hail as hl
import hailtop.utils

mt = hl.read_matrix_table('...')

# write the dosages out once as a BlockMatrix, then read it back
AS_BLOCK_MATRIX_FILE = '...'
hl.linalg.BlockMatrix.write_from_entry_expr(mt.GT.n_alt_alleles(),
                                            AS_BLOCK_MATRIX_FILE)
bm = hl.linalg.BlockMatrix.read(AS_BLOCK_MATRIX_FILE)

GROUP_SIZE = 4096
column_groups = hailtop.utils.grouped(GROUP_SIZE, list(range(mt.count_cols())))

for col_group in column_groups:
    x = bm.filter_cols(col_group).to_numpy()
    # ... use the columns in x any way you like ...
A BlockMatrix is partitioned in both rows and columns so you can (relatively) efficiently read a span of columns. The default block size is 4096x4096, so reading groups of 4096 columns is the most efficient option. If you can’t fit 4096 column vectors in memory, try smaller powers of two.
Thanks everyone. I gave it a go and kept getting OOM errors even after reducing the group size by several factors of 2, so I decided to try running it with group_size 32 just to make sure it worked, with the plan to increase the group size as much as possible once I’d established that it worked. I just wanted to see whether (and how fast) each group loaded, so for now I just put a print in the loop. This is the error that I got:
FatalError Traceback (most recent call last)
<ipython-input-5-3a08ca5bd72f> in <module>
1 for col_group in column_groups:
----> 2 x = bm.filter_cols(col_group).to_numpy()
3 print(x)
<decorator-gen-1469> in to_numpy(self, _force_blocking)
~/anaconda3/envs/hailtf/lib/python3.6/site-packages/hail/typecheck/check.py in wrapper(__original_func, *args, **kwargs)
612 def wrapper(__original_func, *args, **kwargs):
613 args_, kwargs_ = check_all(__original_func, args, kwargs, checkers, is_method=is_method)
--> 614 return __original_func(*args_, **kwargs_)
615
616 return wrapper
~/anaconda3/envs/hailtf/lib/python3.6/site-packages/hail/linalg/blockmatrix.py in to_numpy(self, _force_blocking)
1196 path = new_local_temp_file()
1197 uri = local_path_uri(path)
-> 1198 self.tofile(uri)
1199 return np.fromfile(path).reshape((self.n_rows, self.n_cols))
1200
<decorator-gen-1467> in tofile(self, uri)
~/anaconda3/envs/hailtf/lib/python3.6/site-packages/hail/typecheck/check.py in wrapper(__original_func, *args, **kwargs)
612 def wrapper(__original_func, *args, **kwargs):
613 args_, kwargs_ = check_all(__original_func, args, kwargs, checkers, is_method=is_method)
--> 614 return __original_func(*args_, **kwargs_)
615
616 return wrapper
~/anaconda3/envs/hailtf/lib/python3.6/site-packages/hail/linalg/blockmatrix.py in tofile(self, uri)
1168
1169 writer = BlockMatrixBinaryWriter(uri)
-> 1170 Env.backend().execute(BlockMatrixWrite(self._bmir, writer))
1171
1172 @typecheck_method(_force_blocking=bool)
~/anaconda3/envs/hailtf/lib/python3.6/site-packages/hail/backend/spark_backend.py in execute(self, ir, timed)
294 jir = self._to_java_value_ir(ir)
295 # print(self._hail_package.expr.ir.Pretty.apply(jir, True, -1))
--> 296 result = json.loads(self._jhc.backend().executeJSON(jir))
297 value = ir.typ._from_json(result['value'])
298 timings = result['timings']
~/anaconda3/envs/hailtf/lib/python3.6/site-packages/py4j/java_gateway.py in __call__(self, *args)
1255 answer = self.gateway_client.send_command(command)
1256 return_value = get_return_value(
-> 1257 answer, self.gateway_client, self.target_id, self.name)
1258
1259 for temp_arg in temp_args:
~/anaconda3/envs/hailtf/lib/python3.6/site-packages/hail/backend/spark_backend.py in deco(*args, **kwargs)
39 raise FatalError('%s\n\nJava stack trace:\n%s\n'
40 'Hail version: %s\n'
---> 41 'Error summary: %s' % (deepest, full, hail.__version__, deepest)) from None
42 except pyspark.sql.utils.CapturedException as e:
43 raise FatalError('%s\n\nJava stack trace:\n%s\n'
FatalError: SparkException: Job 0 cancelled because SparkContext was shut down
Java stack trace:
org.apache.spark.SparkException: Job 0 cancelled because SparkContext was shut down
at org.apache.spark.scheduler.DAGScheduler$$anonfun$cleanUpAfterSchedulerStop$1.apply(DAGScheduler.scala:932)
at org.apache.spark.scheduler.DAGScheduler$$anonfun$cleanUpAfterSchedulerStop$1.apply(DAGScheduler.scala:930)
at scala.collection.mutable.HashSet.foreach(HashSet.scala:78)
at org.apache.spark.scheduler.DAGScheduler.cleanUpAfterSchedulerStop(DAGScheduler.scala:930)
at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.onStop(DAGScheduler.scala:2128)
at org.apache.spark.util.EventLoop.stop(EventLoop.scala:84)
at org.apache.spark.scheduler.DAGScheduler.stop(DAGScheduler.scala:2041)
at org.apache.spark.SparkContext$$anonfun$stop$6.apply$mcV$sp(SparkContext.scala:1949)
at org.apache.spark.util.Utils$.tryLogNonFatalError(Utils.scala:1340)
at org.apache.spark.SparkContext.stop(SparkContext.scala:1948)
at org.apache.spark.SparkContext$$anonfun$2.apply$mcV$sp(SparkContext.scala:575)
at org.apache.spark.util.SparkShutdownHook.run(ShutdownHookManager.scala:216)
at org.apache.spark.util.SparkShutdownHookManager$$anonfun$runAll$1$$anonfun$apply$mcV$sp$1.apply$mcV$sp(ShutdownHookManager.scala:188)
at org.apache.spark.util.SparkShutdownHookManager$$anonfun$runAll$1$$anonfun$apply$mcV$sp$1.apply(ShutdownHookManager.scala:188)
at org.apache.spark.util.SparkShutdownHookManager$$anonfun$runAll$1$$anonfun$apply$mcV$sp$1.apply(ShutdownHookManager.scala:188)
at org.apache.spark.util.Utils$.logUncaughtExceptions(Utils.scala:1945)
at org.apache.spark.util.SparkShutdownHookManager$$anonfun$runAll$1.apply$mcV$sp(ShutdownHookManager.scala:188)
at org.apache.spark.util.SparkShutdownHookManager$$anonfun$runAll$1.apply(ShutdownHookManager.scala:188)
at org.apache.spark.util.SparkShutdownHookManager$$anonfun$runAll$1.apply(ShutdownHookManager.scala:188)
at scala.util.Try$.apply(Try.scala:192)
at org.apache.spark.util.SparkShutdownHookManager.runAll(ShutdownHookManager.scala:188)
at org.apache.spark.util.SparkShutdownHookManager$$anon$2.run(ShutdownHookManager.scala:178)
at org.apache.hadoop.util.ShutdownHookManager$1.run(ShutdownHookManager.java:54)
at org.apache.spark.scheduler.DAGScheduler.runJob(DAGScheduler.scala:737)
at org.apache.spark.SparkContext.runJob(SparkContext.scala:2061)
at org.apache.spark.SparkContext.runJob(SparkContext.scala:2082)
at org.apache.spark.SparkContext.runJob(SparkContext.scala:2101)
at org.apache.spark.SparkContext.runJob(SparkContext.scala:2126)
at org.apache.spark.rdd.RDD$$anonfun$collect$1.apply(RDD.scala:945)
at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:151)
at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:112)
at org.apache.spark.rdd.RDD.withScope(RDD.scala:363)
at org.apache.spark.rdd.RDD.collect(RDD.scala:944)
at is.hail.linalg.BlockMatrix.toBreezeMatrix(BlockMatrix.scala:921)
at is.hail.expr.ir.BlockMatrixBinaryWriter.apply(BlockMatrixWriter.scala:116)
at is.hail.expr.ir.Interpret$.run(Interpret.scala:821)
at is.hail.expr.ir.Interpret$.alreadyLowered(Interpret.scala:53)
at is.hail.expr.ir.InterpretNonCompilable$.interpretAndCoerce$1(InterpretNonCompilable.scala:16)
at is.hail.expr.ir.InterpretNonCompilable$.is$hail$expr$ir$InterpretNonCompilable$$rewrite$1(InterpretNonCompilable.scala:53)
at is.hail.expr.ir.InterpretNonCompilable$.apply(InterpretNonCompilable.scala:58)
at is.hail.expr.ir.lowering.InterpretNonCompilablePass$.transform(LoweringPass.scala:66)
at is.hail.expr.ir.lowering.LoweringPass$$anonfun$apply$3$$anonfun$1.apply(LoweringPass.scala:15)
at is.hail.expr.ir.lowering.LoweringPass$$anonfun$apply$3$$anonfun$1.apply(LoweringPass.scala:15)
at is.hail.utils.ExecutionTimer.time(ExecutionTimer.scala:69)
at is.hail.expr.ir.lowering.LoweringPass$$anonfun$apply$3.apply(LoweringPass.scala:15)
at is.hail.expr.ir.lowering.LoweringPass$$anonfun$apply$3.apply(LoweringPass.scala:13)
at is.hail.utils.ExecutionTimer.time(ExecutionTimer.scala:69)
at is.hail.expr.ir.lowering.LoweringPass$class.apply(LoweringPass.scala:13)
at is.hail.expr.ir.lowering.InterpretNonCompilablePass$.apply(LoweringPass.scala:61)
at is.hail.expr.ir.lowering.LoweringPipeline$$anonfun$apply$1.apply(LoweringPipeline.scala:14)
at is.hail.expr.ir.lowering.LoweringPipeline$$anonfun$apply$1.apply(LoweringPipeline.scala:12)
at scala.collection.IndexedSeqOptimized$class.foreach(IndexedSeqOptimized.scala:33)
at scala.collection.mutable.WrappedArray.foreach(WrappedArray.scala:35)
at is.hail.expr.ir.lowering.LoweringPipeline.apply(LoweringPipeline.scala:12)
at is.hail.expr.ir.CompileAndEvaluate$._apply(CompileAndEvaluate.scala:28)
at is.hail.backend.spark.SparkBackend.is$hail$backend$spark$SparkBackend$$_execute(SparkBackend.scala:318)
at is.hail.backend.spark.SparkBackend$$anonfun$execute$1.apply(SparkBackend.scala:305)
at is.hail.backend.spark.SparkBackend$$anonfun$execute$1.apply(SparkBackend.scala:304)
at is.hail.expr.ir.ExecuteContext$$anonfun$scoped$1.apply(ExecuteContext.scala:20)
at is.hail.expr.ir.ExecuteContext$$anonfun$scoped$1.apply(ExecuteContext.scala:18)
at is.hail.utils.package$.using(package.scala:602)
at is.hail.annotations.Region$.scoped(Region.scala:18)
at is.hail.expr.ir.ExecuteContext$.scoped(ExecuteContext.scala:18)
at is.hail.backend.spark.SparkBackend.withExecuteContext(SparkBackend.scala:230)
at is.hail.backend.spark.SparkBackend.execute(SparkBackend.scala:304)
at is.hail.backend.spark.SparkBackend.executeJSON(SparkBackend.scala:324)
at sun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)
at sun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)
at sun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)
at java.lang.reflect.Method.invoke(Method.java:498)
at py4j.reflection.MethodInvoker.invoke(MethodInvoker.java:244)
at py4j.reflection.ReflectionEngine.invoke(ReflectionEngine.java:357)
at py4j.Gateway.invoke(Gateway.java:282)
at py4j.commands.AbstractCommand.invokeMethod(AbstractCommand.java:132)
at py4j.commands.CallCommand.execute(CallCommand.java:79)
at py4j.GatewayConnection.run(GatewayConnection.java:238)
at java.lang.Thread.run(Thread.java:748)
Hail version: 0.2.50-32fc1de02d32
Error summary: SparkException: Job 0 cancelled because SparkContext was shut down
Oops sorry, I missed your reply. Do you have the hail log file? That has the true error message that we need.
How many variants do you have? I’m surprised you’re running out of memory. Do you have your memory set to the maximum available? (See the “Java Heap Space out of memory” thread.)
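If you’re running locally, one common way to raise the driver memory is to set PYSPARK_SUBMIT_ARGS before Hail starts the JVM (the 8g below is only an example value; use whatever the machine actually has):

import os

# must be set before hail / pyspark initialise the JVM
os.environ['PYSPARK_SUBMIT_ARGS'] = '--driver-memory 8g pyspark-shell'

import hail as hl
hl.init()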
Wow, I feel silly: I had reduced my Spark memory so I could write TFRecords in parallel and forgot to put it back. That fixed it, and it can now do 4096 and loads them pretty fast! Thank you so much for your help! I think this will make a viable data generator to use with Keras.
@danking: follow-up question. I need to relate the entries to the phenotypes, and obviously the BlockMatrix loses the sample IDs. Is sample ordering preserved in this transformation? If I extract the samples from the original MatrixTable (e.g. with mt.s.collect()), how does that ordering relate to the ordering of the BlockMatrix? I have 50K SNPs per sample, in case that’s pertinent.
Edit to add: I was wondering whether getting all 50K SNPs for each sample would be an issue because of the checkerboard partitions… ignore that, as I’ve just realised that the column groups already contain all the SNPs. So the only question is: is sample ordering preserved and the same as mt.s.collect()?
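For concreteness, what I’m hoping to do is roughly this (continuing from the BlockMatrix snippet above, and assuming the column order matches mt.s.collect(), which is exactly the part I’d like to confirm):

samples = mt.s.collect()
for col_group in column_groups:
    x = bm.filter_cols(col_group).to_numpy()   # shape: (n_variants, len(col_group))
    ids = [samples[j] for j in col_group]      # sample IDs for these columns, if order is preserved
    # ... look up phenotypes for ids and pair them with the columns of x ...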
Hello again. I have successfully created a generator that gets numpy arrays of 4096 samples out of a BlockMatrix pretty efficiently. The issue is that each time the generator is called on the BlockMatrix, a new temporary file is created in /tmp, and over the course of training a model with many iterations this rapidly becomes a storage problem, as the files don’t seem to be deleted at the end of the training epoch. Is there some way to reroute the temporary files somewhere else with more space? Or to make sure they’re deleted as soon as they’re no longer needed (if that isn’t already happening)?
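For reference, the generator is shaped roughly like this (a simplified sketch; the phenotype/label side and any shuffling are omitted):

import numpy as np
import tensorflow as tf
import hail as hl

class DosageSequence(tf.keras.utils.Sequence):
    """Yields one group of sample columns per batch (labels omitted for brevity)."""

    def __init__(self, bm_path, group_size=4096):
        self.bm = hl.linalg.BlockMatrix.read(bm_path)
        self.group_size = group_size
        self.n_samples = self.bm.n_cols

    def __len__(self):
        return int(np.ceil(self.n_samples / self.group_size))

    def __getitem__(self, idx):
        start = idx * self.group_size
        end = min(start + self.group_size, self.n_samples)
        # to_numpy() is what writes the temporary file under /tmp on every call
        x = self.bm.filter_cols(list(range(start, end))).to_numpy().T
        return x.astype(np.float32)   # shape: (samples in group, n_variants)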
Oh, that’s great, thanks. Do you have any recommendations for patching around this in the meantime? I tried routinely clearing /tmp every few minutes, but I ended up deleting some files while they were still being used. Would clearing it at the end of each epoch work? How long are those files needed?