SNP dosages to numpy/pandas?

Hi Guys,
Do you have a recommendation for how to best export SNP dosages from a MatrixTable to a numpy/pandas object?

The goal is to filter subjects and create a numpy matrix of 0/1/2 alt allele counts for use as input to a TensorFlow model. The overall data size isn’t astronomical and can fit into RAM.

Subject filtering isn’t an issue, but I’ve been toying with the “n_alt_alleles” function, combined with collect() / take(), and haven’t identified a good solution.

I would appreciate any thoughts or advice you might have!

Thanks,
Kevin

Try mt.GT.n_alt_alleles().export('path...'), then ingest that with numpy/pandas.
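
For the ingestion step, something like this should work (a sketch; it assumes the export wrote a TSV whose leading columns are the row key, e.g. locus and alleles, followed by one column per sample):

import numpy as np
import pandas as pd

# illustrative path; .bgz output is gzip-compatible
df = pd.read_csv('dosages.tsv.bgz', sep='\t', compression='gzip')
# drop the row-key columns, keeping one column per sample
dosages = df.drop(columns=['locus', 'alleles']).to_numpy(dtype=np.float32)
# transpose so rows are samples and columns are variants
dosages = dosages.T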

Hi @tpoterba, I have the same question, except my data IS astronomical and does not fit into RAM, at all.

My current method is to use n_alt_alleles, convert the MatrixTable into a Table with columns indexed by sample ID, and then collect each sample one by one, convert it into a TensorFlow Example object, and write it to a TFRecord. This works but isn't very efficient (I'm looking into parallelising it). Wondering if you have any suggestions for a better pipeline for this.
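
Roughly this kind of thing (an illustrative sketch, not my exact code; paths are made up and missing genotypes aren't handled):

import hail as hl
import tensorflow as tf

mt = hl.read_matrix_table('data.mt')    # made-up path
sample_ids = mt.s.collect()

# localize entries so each row carries an array of per-sample dosages
ht = mt.select_entries(n=mt.GT.n_alt_alleles()).localize_entries('gts')

with tf.io.TFRecordWriter('dosages.tfrecord') as writer:    # made-up path
    for i, sample_id in enumerate(sample_ids):
        # one full pass over the data per sample; this is the slow part
        dosages = ht.gts[i]['n'].collect()
        example = tf.train.Example(features=tf.train.Features(feature={
            'dosages': tf.train.Feature(int64_list=tf.train.Int64List(value=dosages))
        }))
        writer.write(example.SerializeToString())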

I think there will be better strategies than the one you’re employing now, but nothing immediately jumps to mind. I’ll ruminate on this a bit and respond this afternoon or tomorrow.

Hi @tpoterba, wondering if you’ve had a chance to give this more thought? I still haven’t come up with a better strategy myself and even parallelised, this is taking forever.

Ah, thanks for the bump! I’ve thought about this a bit, but don’t have any great solution. Part of why this is hard is that you’re doing a transpose – Hail MatrixTables are stored and processed row-major, but you want to export columns to TF individually. I think perhaps you can improve the speed a bit by collecting samples in groups (as many can fit in memory at once), something like the following:

ht = mt.select_entries(n = mt.GT.n_alt_alleles()).localize_entries('gts')

# destructure array of structs into a plain array of dosages
ht = ht.select(gts = ht.gts.map(lambda x: x['n']))

# checkpoint so there's minimal work to do on each iteration
ht = ht.checkpoint('...')

# now loop through in groups of `GROUP_SIZE` samples
N_SAMPLES = mt.count_cols()
for start in range(0, N_SAMPLES, GROUP_SIZE):
    end = min(start + GROUP_SIZE, N_SAMPLES)
    local_data = ht.gts[start:end].collect()
    ... process this and send to TF...

Thanks for the response! I’m not sure I quite follow what you’re suggesting. Would this be effectively transposing the whole thing first (by making it into an array of structs)? Also what is the checkpoint? Unfortunately the batches I can fit in memory seem to be like 3 samples, which doesn’t help much.

I had originally considered getting the data out of the Hail format and feeding it into a TensorFlow pipeline with a custom generator, but I worried (and still do) that it would be extremely slow due to Hail's lazy evaluation. It seemed better to convert all the data ahead of time into something that TensorFlow can read quickly, like a TFRecord, so that at least my model training would be fast, but I underestimated how incredibly slow TensorFlow's IO class is. Another option is HDF5, but I'm not very familiar with that format and I don't know whether it's any easier to convert into.
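
For reference, the kind of thing I have in mind with HDF5 would be roughly this (an untested sketch with made-up names and dimensions):

import h5py

# made-up dimensions; one on-disk matrix of int8 dosages
n_samples, n_variants = 100_000, 50_000
with h5py.File('dosages.h5', 'w') as f:
    dset = f.create_dataset('dosages', shape=(n_samples, n_variants),
                            dtype='i1', chunks=(64, n_variants))
    # fill in blocks of samples, e.g. from a grouped collect() loop:
    # dset[start:end, :] = block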

Generally I’m really stuck on this so any ideas are extremely appreciated.

So, I have a kind of crazy suggestion.

import hail as hl
import hailtop.utils

mt = hl.read_matrix_table('...')

AS_BLOCK_MATRIX_FILE = '...'

hl.BlockMatrix.write_from_entry_expr(mt.GT.n_alt_alleles(),
                                     AS_BLOCK_MATRIX_FILE)

bm = hl.BlockMatrix.read(AS_BLOCK_MATRIX_FILE)

GROUP_SIZE = 4096

column_groups = hailtop.utils.grouped(GROUP_SIZE, list(range(mt.count_cols())))
for col_group in column_groups:
    x = bm.filter_cols(col_group).to_numpy()
    ... use the block of columns, x, any way you like ...

A BlockMatrix is partitioned in both rows and columns so you can (relatively) efficiently read a span of columns. The default block size is 4096x4096, so reading groups of 4096 columns is the most efficient option. If you can’t fit 4096 column vectors in memory, try smaller powers of two.
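
As a rough back-of-the-envelope check (illustrative numbers only; BlockMatrix entries are 64-bit floats):

n_variants = 50_000                               # hypothetical variant count
group_size = 4096
bytes_per_group = n_variants * group_size * 8     # float64 entries
print(bytes_per_group / 2**30)                    # ~1.5 GiB per group of 4096 columns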


This sounds really intriguing! I will give it a go and let you know what happens…

@danking I get this:

AttributeError: module 'hail' has no attribute 'BlockMatrix'

Should be hl.linalg.BlockMatrix. Dan made a typo. See https://hail.is/docs/0.2/linalg/hail.linalg.BlockMatrix.html#blockmatrix

Thanks everyone. I gave it a go but kept getting OOM errors, even after reducing the group size by several factors of two. So I tried running it with a group size of 32 just to establish that it works, with the plan to increase it as far as possible from there. To see if/how fast each group loads, I just put a print in the loop for now. This is the error I got:

FatalError                                Traceback (most recent call last)
<ipython-input-5-3a08ca5bd72f> in <module>
  1 for col_group in column_groups:
----> 2     x = bm.filter_cols(col_group).to_numpy()
  3     print(x)

    <decorator-gen-1469> in to_numpy(self, _force_blocking)

~/anaconda3/envs/hailtf/lib/python3.6/site-packages/hail/typecheck/check.py in wrapper(__original_func, *args, **kwargs)
612     def wrapper(__original_func, *args, **kwargs):
613         args_, kwargs_ = check_all(__original_func, args, kwargs, checkers, is_method=is_method)
--> 614         return __original_func(*args_, **kwargs_)
615 
616     return wrapper

~/anaconda3/envs/hailtf/lib/python3.6/site-packages/hail/linalg/blockmatrix.py in to_numpy(self, _force_blocking)
   1196         path = new_local_temp_file()
   1197         uri = local_path_uri(path)
-> 1198         self.tofile(uri)
   1199         return np.fromfile(path).reshape((self.n_rows, self.n_cols))
   1200 

<decorator-gen-1467> in tofile(self, uri)

~/anaconda3/envs/hailtf/lib/python3.6/site-packages/hail/typecheck/check.py in wrapper(__original_func, *args, **kwargs)
612     def wrapper(__original_func, *args, **kwargs):
613         args_, kwargs_ = check_all(__original_func, args, kwargs, checkers, is_method=is_method)
--> 614         return __original_func(*args_, **kwargs_)
615 
616     return wrapper

~/anaconda3/envs/hailtf/lib/python3.6/site-packages/hail/linalg/blockmatrix.py in tofile(self, uri)
   1168 
   1169         writer = BlockMatrixBinaryWriter(uri)
-> 1170         Env.backend().execute(BlockMatrixWrite(self._bmir, writer))
   1171 
   1172     @typecheck_method(_force_blocking=bool)

~/anaconda3/envs/hailtf/lib/python3.6/site-packages/hail/backend/spark_backend.py in execute(self, ir, timed)
294         jir = self._to_java_value_ir(ir)
295         # print(self._hail_package.expr.ir.Pretty.apply(jir, True, -1))
--> 296         result = json.loads(self._jhc.backend().executeJSON(jir))
297         value = ir.typ._from_json(result['value'])
298         timings = result['timings']

~/anaconda3/envs/hailtf/lib/python3.6/site-packages/py4j/java_gateway.py in __call__(self, *args)
   1255         answer = self.gateway_client.send_command(command)
   1256         return_value = get_return_value(
-> 1257             answer, self.gateway_client, self.target_id, self.name)
   1258 
   1259         for temp_arg in temp_args:

~/anaconda3/envs/hailtf/lib/python3.6/site-packages/hail/backend/spark_backend.py in deco(*args, **kwargs)
 39             raise FatalError('%s\n\nJava stack trace:\n%s\n'
 40                              'Hail version: %s\n'
---> 41                              'Error summary: %s' % (deepest, full, hail.__version__, deepest)) from None
 42         except pyspark.sql.utils.CapturedException as e:
 43             raise FatalError('%s\n\nJava stack trace:\n%s\n'

FatalError: SparkException: Job 0 cancelled because SparkContext was shut down

Java stack trace:
org.apache.spark.SparkException: Job 0 cancelled because SparkContext was shut down
	at org.apache.spark.scheduler.DAGScheduler$$anonfun$cleanUpAfterSchedulerStop$1.apply(DAGScheduler.scala:932)
	at org.apache.spark.scheduler.DAGScheduler$$anonfun$cleanUpAfterSchedulerStop$1.apply(DAGScheduler.scala:930)
	at scala.collection.mutable.HashSet.foreach(HashSet.scala:78)
	at org.apache.spark.scheduler.DAGScheduler.cleanUpAfterSchedulerStop(DAGScheduler.scala:930)
	at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.onStop(DAGScheduler.scala:2128)
	at org.apache.spark.util.EventLoop.stop(EventLoop.scala:84)
	at org.apache.spark.scheduler.DAGScheduler.stop(DAGScheduler.scala:2041)
	at org.apache.spark.SparkContext$$anonfun$stop$6.apply$mcV$sp(SparkContext.scala:1949)
	at org.apache.spark.util.Utils$.tryLogNonFatalError(Utils.scala:1340)
	at org.apache.spark.SparkContext.stop(SparkContext.scala:1948)
	at org.apache.spark.SparkContext$$anonfun$2.apply$mcV$sp(SparkContext.scala:575)
	at org.apache.spark.util.SparkShutdownHook.run(ShutdownHookManager.scala:216)
	at org.apache.spark.util.SparkShutdownHookManager$$anonfun$runAll$1$$anonfun$apply$mcV$sp$1.apply$mcV$sp(ShutdownHookManager.scala:188)
	at org.apache.spark.util.SparkShutdownHookManager$$anonfun$runAll$1$$anonfun$apply$mcV$sp$1.apply(ShutdownHookManager.scala:188)
	at org.apache.spark.util.SparkShutdownHookManager$$anonfun$runAll$1$$anonfun$apply$mcV$sp$1.apply(ShutdownHookManager.scala:188)
	at org.apache.spark.util.Utils$.logUncaughtExceptions(Utils.scala:1945)
	at org.apache.spark.util.SparkShutdownHookManager$$anonfun$runAll$1.apply$mcV$sp(ShutdownHookManager.scala:188)
	at org.apache.spark.util.SparkShutdownHookManager$$anonfun$runAll$1.apply(ShutdownHookManager.scala:188)
	at org.apache.spark.util.SparkShutdownHookManager$$anonfun$runAll$1.apply(ShutdownHookManager.scala:188)
	at scala.util.Try$.apply(Try.scala:192)
	at org.apache.spark.util.SparkShutdownHookManager.runAll(ShutdownHookManager.scala:188)
	at org.apache.spark.util.SparkShutdownHookManager$$anon$2.run(ShutdownHookManager.scala:178)
	at org.apache.hadoop.util.ShutdownHookManager$1.run(ShutdownHookManager.java:54)
	at org.apache.spark.scheduler.DAGScheduler.runJob(DAGScheduler.scala:737)
	at org.apache.spark.SparkContext.runJob(SparkContext.scala:2061)
	at org.apache.spark.SparkContext.runJob(SparkContext.scala:2082)
	at org.apache.spark.SparkContext.runJob(SparkContext.scala:2101)
	at org.apache.spark.SparkContext.runJob(SparkContext.scala:2126)
	at org.apache.spark.rdd.RDD$$anonfun$collect$1.apply(RDD.scala:945)
	at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:151)
	at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:112)
	at org.apache.spark.rdd.RDD.withScope(RDD.scala:363)
	at org.apache.spark.rdd.RDD.collect(RDD.scala:944)
	at is.hail.linalg.BlockMatrix.toBreezeMatrix(BlockMatrix.scala:921)
	at is.hail.expr.ir.BlockMatrixBinaryWriter.apply(BlockMatrixWriter.scala:116)
	at is.hail.expr.ir.Interpret$.run(Interpret.scala:821)
	at is.hail.expr.ir.Interpret$.alreadyLowered(Interpret.scala:53)
	at is.hail.expr.ir.InterpretNonCompilable$.interpretAndCoerce$1(InterpretNonCompilable.scala:16)
	at is.hail.expr.ir.InterpretNonCompilable$.is$hail$expr$ir$InterpretNonCompilable$$rewrite$1(InterpretNonCompilable.scala:53)
	at is.hail.expr.ir.InterpretNonCompilable$.apply(InterpretNonCompilable.scala:58)
	at is.hail.expr.ir.lowering.InterpretNonCompilablePass$.transform(LoweringPass.scala:66)
	at is.hail.expr.ir.lowering.LoweringPass$$anonfun$apply$3$$anonfun$1.apply(LoweringPass.scala:15)
	at is.hail.expr.ir.lowering.LoweringPass$$anonfun$apply$3$$anonfun$1.apply(LoweringPass.scala:15)
	at is.hail.utils.ExecutionTimer.time(ExecutionTimer.scala:69)
	at is.hail.expr.ir.lowering.LoweringPass$$anonfun$apply$3.apply(LoweringPass.scala:15)
	at is.hail.expr.ir.lowering.LoweringPass$$anonfun$apply$3.apply(LoweringPass.scala:13)
	at is.hail.utils.ExecutionTimer.time(ExecutionTimer.scala:69)
	at is.hail.expr.ir.lowering.LoweringPass$class.apply(LoweringPass.scala:13)
	at is.hail.expr.ir.lowering.InterpretNonCompilablePass$.apply(LoweringPass.scala:61)
	at is.hail.expr.ir.lowering.LoweringPipeline$$anonfun$apply$1.apply(LoweringPipeline.scala:14)
	at is.hail.expr.ir.lowering.LoweringPipeline$$anonfun$apply$1.apply(LoweringPipeline.scala:12)
	at scala.collection.IndexedSeqOptimized$class.foreach(IndexedSeqOptimized.scala:33)
	at scala.collection.mutable.WrappedArray.foreach(WrappedArray.scala:35)
	at is.hail.expr.ir.lowering.LoweringPipeline.apply(LoweringPipeline.scala:12)
	at is.hail.expr.ir.CompileAndEvaluate$._apply(CompileAndEvaluate.scala:28)
	at is.hail.backend.spark.SparkBackend.is$hail$backend$spark$SparkBackend$$_execute(SparkBackend.scala:318)
	at is.hail.backend.spark.SparkBackend$$anonfun$execute$1.apply(SparkBackend.scala:305)
	at is.hail.backend.spark.SparkBackend$$anonfun$execute$1.apply(SparkBackend.scala:304)
	at is.hail.expr.ir.ExecuteContext$$anonfun$scoped$1.apply(ExecuteContext.scala:20)
	at is.hail.expr.ir.ExecuteContext$$anonfun$scoped$1.apply(ExecuteContext.scala:18)
	at is.hail.utils.package$.using(package.scala:602)
	at is.hail.annotations.Region$.scoped(Region.scala:18)
	at is.hail.expr.ir.ExecuteContext$.scoped(ExecuteContext.scala:18)
	at is.hail.backend.spark.SparkBackend.withExecuteContext(SparkBackend.scala:230)
	at is.hail.backend.spark.SparkBackend.execute(SparkBackend.scala:304)
	at is.hail.backend.spark.SparkBackend.executeJSON(SparkBackend.scala:324)
	at sun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)
	at sun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)
	at sun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)
	at java.lang.reflect.Method.invoke(Method.java:498)
	at py4j.reflection.MethodInvoker.invoke(MethodInvoker.java:244)
	at py4j.reflection.ReflectionEngine.invoke(ReflectionEngine.java:357)
	at py4j.Gateway.invoke(Gateway.java:282)
	at py4j.commands.AbstractCommand.invokeMethod(AbstractCommand.java:132)
	at py4j.commands.CallCommand.execute(CallCommand.java:79)
	at py4j.GatewayConnection.run(GatewayConnection.java:238)
	at java.lang.Thread.run(Thread.java:748)



Hail version: 0.2.50-32fc1de02d32
Error summary: SparkException: Job 0 cancelled because SparkContext was shut down

@danking @johnc1231 @tpoterba any thoughts on what’s gone wrong here? I’m quite keen to use this blockmatrix approach if at all possible!

Oops sorry, I missed your reply. Do you have the hail log file? That has the true error message that we need.

How many variants do you have? I’m surprised you’re running out of memory. Do you have your memory set to the maximum available? (See the “Java Heap Space out of memory” thread.)
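
A sketch of one common way to do that for a local Spark backend (the 8g values are just examples; use what your machine actually has):

import os

# must be set before Hail starts its Spark backend
os.environ['PYSPARK_SUBMIT_ARGS'] = '--driver-memory 8g --executor-memory 8g pyspark-shell'

import hail as hl
hl.init()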


Wow, I feel silly – I had reduced my Spark memory so I could write TFRecords in parallel and forgot to put it back. That fixed it: it can now handle groups of 4096, and it loads them pretty fast! Thank you so much for your help! I think this will make a viable data generator to use with Keras.

@danking: follow-up question. I need to relate the entries to the phenotypes, and obviously the BlockMatrix loses the sample IDs. Is sample ordering preserved in this transformation? If I extract the samples from the original MatrixTable (e.g. with mt.s.collect()), how does that ordering relate to the ordering of the BlockMatrix columns? I have 50K SNPs per sample, in case that's pertinent.

Edit to add: I was wondering whether getting all 50K SNPs for each sample would be an issue because of the checkerboard partitions… ignore that, as I've just realised that the column groups already contain all the SNPs. So the only question is: is the sample ordering preserved, and is it the same as mt.s.collect()?

Yes, the ordering is preserved.
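
Concretely, relating columns back to sample IDs could look like this (a sketch reusing mt, bm, and GROUP_SIZE from the earlier snippet):

import hailtop.utils

sample_ids = mt.s.collect()    # one ID per BlockMatrix column, in column order

column_groups = hailtop.utils.grouped(GROUP_SIZE, list(range(len(sample_ids))))
for col_group in column_groups:
    x = bm.filter_cols(col_group).to_numpy()    # n_variants x len(col_group)
    group_ids = [sample_ids[i] for i in col_group]
    # look up phenotypes for group_ids and pair them with the columns of x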


Hello again. I have successfully created a generator that pulls numpy arrays of 4096 samples out of a BlockMatrix pretty efficiently. The issue is that each time the generator is called on the BlockMatrix, a new temporary file is created in /tmp, and over the course of training a model for many iterations this rapidly becomes a storage problem, since the files don't seem to be deleted at the end of the training epoch. Is there some way to reroute the temporary files to somewhere with more space? Or to make sure they're deleted as soon as they are no longer needed (if that isn't already happening)?


Oops, sorry! I made a pull request to eagerly delete those files. We should release a new version of Hail soon!


Oh, that’s great, thanks. Do you have any recommendations for working around this in the meantime? I tried routinely clearing /tmp every few minutes, but I ended up deleting some files while they were still being used. Would clearing it at the end of each epoch work? How long are those files needed for?