Dephase genotypes

Hello,
hail version: 0.2.107-2387bb00ceee

This is sort of related to "[query] bug in group_cols_by" (hail-is/hail#13287 on GitHub).

In the All of Us dataset we have scattered genotypes that were phased by the DRAGEN pipeline. This is not ideal, as downstream haplotype callers are unable to assign appropriate haplotypes when a gene is neither fully phased nor fully unphased. I would like to dephase these genotypes. I am able to reproduce the issue with the public gnomAD HGDP + 1000 Genomes data.
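To be concrete, by "dephase" I mean turning a phased call like 0|1 back into its unphased form 0/1. A minimal sketch using Hail's call expressions (the call value here is illustrative):

import hail as hl

c = hl.call(0, 1, phased=True)  # the phased call 0|1
hl.eval(c.phased)               # True
hl.eval(c.unphase())            # the unphased call 0/1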

What’s interesting is that if I comment out either one (or both) of the annotate_entries and filter_entries lines, everything runs fine. But if both are kept, I get the error attached below: the executor dies with exit code 137, which looks like it is being killed, likely for running out of memory.

import hail as hl
import numpy as np
import pandas as pd

def split_multi_allelic(mt):
    # Keep biallelic rows as-is, annotated to match the fields split_multi_hts adds
    bi = mt.filter_rows(hl.len(mt.alleles) == 2)
    bi = bi.annotate_rows(a_index=1, was_split=False)
    # Split only the multi-allelic rows, then recombine with the biallelic rows
    multi = mt.filter_rows(hl.len(mt.alleles) > 2)
    split = hl.split_multi_hts(multi)
    mt = split.union_rows(bi)
    return mt

def create_intervals(data):
    # Build a list of exact Locus values (one per CHROM/POS row) to filter on
    return [
        hl.Locus(chromosome, start, reference_genome="GRCh38")
        for _, (chromosome, start) in data[["CHROM", "POS"]].iterrows()
    ]

mt_gnomad_file = "gs://gcp-public-data--gnomad/release/3.1.2/mt/genomes/gnomad.genomes.v3.1.2.hgdp_1kg_subset_dense.mt"
mt = hl.read_matrix_table(mt_gnomad_file)

dpyd_positions = pd.DataFrame({"CHROM":['chr1', 'chr1'], "POS":[97573863, 97579893]})
intervals = create_intervals(dpyd_positions)
mt = mt.filter_rows(hl.literal(intervals).contains(mt.locus))
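# (Aside, a hypothetical alternative I haven't tried: filtering with true
# intervals lets Hail prune partitions instead of scanning every row, e.g.
#   mt = hl.filter_intervals(
#       mt, [hl.parse_locus_interval('chr1:97573863-97579894',
#                                    reference_genome='GRCh38')])
# The exact-locus filter above is kept to match the original repro.)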

mt = split_multi_allelic(mt)

# Comment out this line and everything runs fine
mt = mt.annotate_entries(GT=hl.if_else(mt.GT.phased, mt.GT.unphase(), mt.GT, missing_false=True))

# Comment out this line and everything also runs fine
mt = mt.filter_entries(mt.GQ >= 20)

# Assign random ancestry labels to the samples so group_cols_by has groups to work with
sample_ids = mt.s.collect()
ancestry_groups = np.random.choice(['ancestry_1', 'ancestry_2', 'ancestry_3', 'ancestry_4'], size=len(sample_ids))
ancestry = pd.DataFrame({"s": sample_ids, 'ancestry': ancestry_groups})
ancestry_table = hl.Table.from_pandas(ancestry, key='s')
mt = mt.annotate_cols(ancestry=ancestry_table[mt.s].ancestry)

mt_hwe_vals = mt.group_cols_by(mt.ancestry).aggregate(hwe = hl.agg.hardy_weinberg_test(mt.GT))
mt_hwe_vals = mt_hwe_vals.select_rows().select_cols() # drop irrelevant row and column fields
mt_hwe_vals.write(bucket + '/hwe_gnomad.ht', overwrite=True)
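(Aside: a possible workaround, untested, might be to checkpoint between the two entry operations so they run as separate pipelines rather than being fused into one; the temp path in this sketch is hypothetical.)

# Hypothetical, untested workaround sketch: checkpoint between the two entry
# operations so Hail evaluates them in separate jobs.
mt = mt.annotate_entries(GT=hl.if_else(mt.GT.phased, mt.GT.unphase(), mt.GT, missing_false=True))
mt = mt.checkpoint(bucket + '/tmp/dephased.mt', overwrite=True)  # temp path is hypothetical
mt = mt.filter_entries(mt.GQ >= 20)

The failure from the original code as written: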
[Stage 1:>                                                          (0 + 0) / 1]
---------------------------------------------------------------------------
FatalError                                Traceback (most recent call last)
/tmp/ipykernel_3710/1120938232.py in <module>
     20 mt_hwe_vals = mt.group_cols_by(mt.ancestry).aggregate(hwe = hl.agg.hardy_weinberg_test(mt.GT))
     21 mt_hwe_vals = mt_hwe_vals.select_rows().select_cols() # drop irrelevant row and column fields
---> 22 mt_hwe_vals.write(bucket + '/hwe_gnomad.ht', overwrite=True)

<decorator-gen-1336> in write(self, output, overwrite, stage_locally, _codec_spec, _partitions, _checkpoint_file)

/opt/conda/lib/python3.7/site-packages/hail/typecheck/check.py in wrapper(__original_func, *args, **kwargs)
    575     def wrapper(__original_func, *args, **kwargs):
    576         args_, kwargs_ = check_all(__original_func, args, kwargs, checkers, is_method=is_method)
--> 577         return __original_func(*args_, **kwargs_)
    578 
    579     return wrapper

/opt/conda/lib/python3.7/site-packages/hail/matrixtable.py in write(self, output, overwrite, stage_locally, _codec_spec, _partitions, _checkpoint_file)
   2582 
   2583         writer = ir.MatrixNativeWriter(output, overwrite, stage_locally, _codec_spec, _partitions, _partitions_type, _checkpoint_file)
-> 2584         Env.backend().execute(ir.MatrixWrite(self._mir, writer))
   2585 
   2586     class _Show:

/opt/conda/lib/python3.7/site-packages/hail/backend/py4j_backend.py in execute(self, ir, timed)
    103             return (value, timings) if timed else value
    104         except FatalError as e:
--> 105             raise e.maybe_user_error(ir) from None
    106 
    107     async def _async_execute(self, ir, timed=False):

/opt/conda/lib/python3.7/site-packages/hail/backend/py4j_backend.py in execute(self, ir, timed)
     97         # print(self._hail_package.expr.ir.Pretty.apply(jir, True, -1))
     98         try:
---> 99             result_tuple = self._jbackend.executeEncode(jir, stream_codec, timed)
    100             (result, timings) = (result_tuple._1(), result_tuple._2())
    101             value = ir.typ._from_encoding(result)

/opt/conda/lib/python3.7/site-packages/py4j/java_gateway.py in __call__(self, *args)
   1321         answer = self.gateway_client.send_command(command)
   1322         return_value = get_return_value(
-> 1323             answer, self.gateway_client, self.target_id, self.name)
   1324 
   1325         for temp_arg in temp_args:

/opt/conda/lib/python3.7/site-packages/hail/backend/py4j_backend.py in deco(*args, **kwargs)
     29             tpl = Env.jutils().handleForPython(e.java_exception)
     30             deepest, full, error_id = tpl._1(), tpl._2(), tpl._3()
---> 31             raise fatal_error_from_java_error_triplet(deepest, full, error_id) from None
     32         except pyspark.sql.utils.CapturedException as e:
     33             raise FatalError('%s\n\nJava stack trace:\n%s\n'

FatalError: SparkException: Job aborted due to stage failure: Task 0 in stage 1.0 failed 4 times, most recent failure: Lost task 0.3 in stage 1.0 (TID 4) (all-of-us-56-w-0.c.terra-vpc-sc-8f5cdfd2.internal executor 4): ExecutorLostFailure (executor 4 exited caused by one of the running tasks) Reason: Container from a bad node: container_1692191509495_0003_01_000004 on host: all-of-us-56-w-0.c.terra-vpc-sc-8f5cdfd2.internal. Exit status: 137. Diagnostics: [2023-08-16 14:47:55.966]Container killed on request. Exit code is 137
[2023-08-16 14:47:55.967]Container exited with a non-zero exit code 137. 
[2023-08-16 14:47:55.968]Killed by external signal
.
Driver stacktrace:

Java stack trace:
org.apache.spark.SparkException: Job aborted due to stage failure: Task 0 in stage 1.0 failed 4 times, most recent failure: Lost task 0.3 in stage 1.0 (TID 4) (all-of-us-56-w-0.c.terra-vpc-sc-8f5cdfd2.internal executor 4): ExecutorLostFailure (executor 4 exited caused by one of the running tasks) Reason: Container from a bad node: container_1692191509495_0003_01_000004 on host: all-of-us-56-w-0.c.terra-vpc-sc-8f5cdfd2.internal. Exit status: 137. Diagnostics: [2023-08-16 14:47:55.966]Container killed on request. Exit code is 137
[2023-08-16 14:47:55.967]Container exited with a non-zero exit code 137. 
[2023-08-16 14:47:55.968]Killed by external signal
.
Driver stacktrace:
	at org.apache.spark.scheduler.DAGScheduler.failJobAndIndependentStages(DAGScheduler.scala:2304)
	at org.apache.spark.scheduler.DAGScheduler.$anonfun$abortStage$2(DAGScheduler.scala:2253)
	at org.apache.spark.scheduler.DAGScheduler.$anonfun$abortStage$2$adapted(DAGScheduler.scala:2252)
	at scala.collection.mutable.ResizableArray.foreach(ResizableArray.scala:62)
	at scala.collection.mutable.ResizableArray.foreach$(ResizableArray.scala:55)
	at scala.collection.mutable.ArrayBuffer.foreach(ArrayBuffer.scala:49)
	at org.apache.spark.scheduler.DAGScheduler.abortStage(DAGScheduler.scala:2252)
	at org.apache.spark.scheduler.DAGScheduler.$anonfun$handleTaskSetFailed$1(DAGScheduler.scala:1124)
	at org.apache.spark.scheduler.DAGScheduler.$anonfun$handleTaskSetFailed$1$adapted(DAGScheduler.scala:1124)
	at scala.Option.foreach(Option.scala:407)
	at org.apache.spark.scheduler.DAGScheduler.handleTaskSetFailed(DAGScheduler.scala:1124)
	at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.doOnReceive(DAGScheduler.scala:2491)
	at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.onReceive(DAGScheduler.scala:2433)
	at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.onReceive(DAGScheduler.scala:2422)
	at org.apache.spark.util.EventLoop$$anon$1.run(EventLoop.scala:49)
	at org.apache.spark.scheduler.DAGScheduler.runJob(DAGScheduler.scala:902)
	at org.apache.spark.SparkContext.runJob(SparkContext.scala:2204)
	at org.apache.spark.SparkContext.runJob(SparkContext.scala:2225)
	at org.apache.spark.SparkContext.runJob(SparkContext.scala:2244)
	at org.apache.spark.SparkContext.runJob(SparkContext.scala:2269)
	at org.apache.spark.rdd.RDD.$anonfun$collect$1(RDD.scala:1030)
	at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:151)
	at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:112)
	at org.apache.spark.rdd.RDD.withScope(RDD.scala:414)
	at org.apache.spark.rdd.RDD.collect(RDD.scala:1029)
	at is.hail.backend.spark.SparkBackend.parallelizeAndComputeWithIndex(SparkBackend.scala:355)
	at is.hail.backend.BackendUtils.collectDArray(BackendUtils.scala:43)
	at __C1279Compiled.__m1562split_CollectDistributedArray(Emit.scala)
	at __C1279Compiled.__m1507split_Let(Emit.scala)
	at __C1279Compiled.apply(Emit.scala)
	at is.hail.expr.ir.CompileAndEvaluate$.$anonfun$_apply$3(CompileAndEvaluate.scala:57)
	at scala.runtime.java8.JFunction0$mcV$sp.apply(JFunction0$mcV$sp.java:23)
	at is.hail.utils.ExecutionTimer.time(ExecutionTimer.scala:81)
	at is.hail.expr.ir.CompileAndEvaluate$._apply(CompileAndEvaluate.scala:57)
	at is.hail.expr.ir.CompileAndEvaluate$.evalToIR(CompileAndEvaluate.scala:30)
	at is.hail.expr.ir.LowerOrInterpretNonCompilable$.evaluate$1(LowerOrInterpretNonCompilable.scala:30)
	at is.hail.expr.ir.LowerOrInterpretNonCompilable$.rewrite$1(LowerOrInterpretNonCompilable.scala:67)
	at is.hail.expr.ir.LowerOrInterpretNonCompilable$.apply(LowerOrInterpretNonCompilable.scala:72)
	at is.hail.expr.ir.lowering.LowerOrInterpretNonCompilablePass$.transform(LoweringPass.scala:69)
	at is.hail.expr.ir.lowering.LoweringPass.$anonfun$apply$3(LoweringPass.scala:16)
	at is.hail.utils.ExecutionTimer.time(ExecutionTimer.scala:81)
	at is.hail.expr.ir.lowering.LoweringPass.$anonfun$apply$1(LoweringPass.scala:16)
	at is.hail.utils.ExecutionTimer.time(ExecutionTimer.scala:81)
	at is.hail.expr.ir.lowering.LoweringPass.apply(LoweringPass.scala:14)
	at is.hail.expr.ir.lowering.LoweringPass.apply$(LoweringPass.scala:13)
	at is.hail.expr.ir.lowering.LowerOrInterpretNonCompilablePass$.apply(LoweringPass.scala:64)
	at is.hail.expr.ir.lowering.LoweringPipeline.$anonfun$apply$1(LoweringPipeline.scala:15)
	at is.hail.expr.ir.lowering.LoweringPipeline.$anonfun$apply$1$adapted(LoweringPipeline.scala:13)
	at scala.collection.IndexedSeqOptimized.foreach(IndexedSeqOptimized.scala:36)
	at scala.collection.IndexedSeqOptimized.foreach$(IndexedSeqOptimized.scala:33)
	at scala.collection.mutable.WrappedArray.foreach(WrappedArray.scala:38)
	at is.hail.expr.ir.lowering.LoweringPipeline.apply(LoweringPipeline.scala:13)
	at is.hail.expr.ir.CompileAndEvaluate$._apply(CompileAndEvaluate.scala:47)
	at is.hail.backend.spark.SparkBackend._execute(SparkBackend.scala:450)
	at is.hail.backend.spark.SparkBackend.$anonfun$executeEncode$2(SparkBackend.scala:486)
	at is.hail.backend.ExecuteContext$.$anonfun$scoped$3(ExecuteContext.scala:70)
	at is.hail.utils.package$.using(package.scala:635)
	at is.hail.backend.ExecuteContext$.$anonfun$scoped$2(ExecuteContext.scala:70)
	at is.hail.utils.package$.using(package.scala:635)
	at is.hail.annotations.RegionPool$.scoped(RegionPool.scala:17)
	at is.hail.backend.ExecuteContext$.scoped(ExecuteContext.scala:59)
	at is.hail.backend.spark.SparkBackend.withExecuteContext(SparkBackend.scala:339)
	at is.hail.backend.spark.SparkBackend.$anonfun$executeEncode$1(SparkBackend.scala:483)
	at is.hail.utils.ExecutionTimer$.time(ExecutionTimer.scala:52)
	at is.hail.backend.spark.SparkBackend.executeEncode(SparkBackend.scala:482)
	at sun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)
	at sun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)
	at sun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)
	at java.lang.reflect.Method.invoke(Method.java:498)
	at py4j.reflection.MethodInvoker.invoke(MethodInvoker.java:244)
	at py4j.reflection.ReflectionEngine.invoke(ReflectionEngine.java:357)
	at py4j.Gateway.invoke(Gateway.java:282)
	at py4j.commands.AbstractCommand.invokeMethod(AbstractCommand.java:132)
	at py4j.commands.CallCommand.execute(CallCommand.java:79)
	at py4j.GatewayConnection.run(GatewayConnection.java:238)
	at java.lang.Thread.run(Thread.java:750)



Hail version: 0.2.107-2387bb00ceee
Error summary: SparkException: Job aborted due to stage failure: Task 0 in stage 1.0 failed 4 times, most recent failure: Lost task 0.3 in stage 1.0 (TID 4) (all-of-us-56-w-0.c.terra-vpc-sc-8f5cdfd2.internal executor 4): ExecutorLostFailure (executor 4 exited caused by one of the running tasks) Reason: Container from a bad node: container_1692191509495_0003_01_000004 on host: all-of-us-56-w-0.c.terra-vpc-sc-8f5cdfd2.internal. Exit status: 137. Diagnostics: [2023-08-16 14:47:55.966]Container killed on request. Exit code is 137
[2023-08-16 14:47:55.967]Container exited with a non-zero exit code 137. 
[2023-08-16 14:47:55.968]Killed by external signal
.
Driver stacktrace:

Environment is below:
[environment screenshot]

Hail log:
hail-20230816-1502-0.2.107-2387bb00ceee.log (4.4 MB)