Getting "KeyError: 'va'" when using array_agg

I am attempting to use array_agg() and count_where() in tandem and I am running into the following error:

KeyError                                  Traceback (most recent call last)
<ipython-input-5-ada4e7c97a13> in <module>
----> 1 counts = mt_test.aggregate_rows(hl.agg.array_agg(lambda x: hl.agg.counter(x), mt_test.maf_flag))

<decorator-gen-1163> in aggregate_rows(self, expr, _localize)

/opt/conda/default/lib/python3.6/site-packages/hail/typecheck/check.py in wrapper(__original_func, *args, **kwargs)
    583     def wrapper(__original_func, *args, **kwargs):
    584         args_, kwargs_ = check_all(__original_func, args, kwargs, checkers, is_method=is_method)
--> 585         return __original_func(*args_, **kwargs_)
    586 
    587     return wrapper

/opt/conda/default/lib/python3.6/site-packages/hail/matrixtable.py in aggregate_rows(self, expr, _localize)
   1987         agg_ir = TableAggregate(MatrixRowsTable(base._mir), subst_query)
   1988         if _localize:
-> 1989             return Env.backend().execute(agg_ir)
   1990         else:
   1991             return construct_expr(agg_ir, expr.dtype)

/opt/conda/default/lib/python3.6/site-packages/hail/backend/backend.py in execute(self, ir, timed)
    107 
    108     def execute(self, ir, timed=False):
--> 109         result = json.loads(Env.hc()._jhc.backend().executeJSON(self._to_java_ir(ir)))
    110         value = ir.typ._from_json(result['value'])
    111         timings = result['timings']

/opt/conda/default/lib/python3.6/site-packages/hail/backend/backend.py in _to_java_ir(self, ir)
    103             r = CSERenderer(stop_at_jir=True)
    104             # FIXME parse should be static
--> 105             ir._jir = ir.parse(r(ir), ir_map=r.jirs)
    106         return ir._jir
    107 

/opt/conda/default/lib/python3.6/site-packages/hail/ir/renderer.py in __call__(self, root)
    181 
    182     def __call__(self, root: 'ir.BaseIR') -> str:
--> 183         binding_sites = CSEAnalysisPass(self)(root)
    184         return CSEPrintPass(self)(root, binding_sites)
    185 

/opt/conda/default/lib/python3.6/site-packages/hail/ir/renderer.py in __call__(self, root)
    251 
    252             if isinstance(child, ir.IR):
--> 253                 bind_depth = child_frame.bind_depth()
    254                 lets = None
    255                 if bind_depth < len(stack):

/opt/conda/default/lib/python3.6/site-packages/hail/ir/renderer.py in bind_depth(self)
    347                 bind_depth = max(bind_depth, max(self.context[0][var] for var in self.node.free_vars))
    348             if len(self.node.free_agg_vars) > 0:
--> 349                 bind_depth = max(bind_depth, max(self.context[1][var] for var in self.node.free_agg_vars))
    350             if len(self.node.free_scan_vars) > 0:
    351                 bind_depth = max(bind_depth, max(self.context[2][var] for var in self.node.free_scan_vars))

/opt/conda/default/lib/python3.6/site-packages/hail/ir/renderer.py in <genexpr>(.0)
    347                 bind_depth = max(bind_depth, max(self.context[0][var] for var in self.node.free_vars))
    348             if len(self.node.free_agg_vars) > 0:
--> 349                 bind_depth = max(bind_depth, max(self.context[1][var] for var in self.node.free_agg_vars))
    350             if len(self.node.free_scan_vars) > 0:
    351                 bind_depth = max(bind_depth, max(self.context[2][var] for var in self.node.free_scan_vars))

KeyError: 'va'

Here is the line of code that causes this error:
counts = mt_test.aggregate_rows(hl.agg.array_agg(lambda x: hl.agg.counter(x), mt_test.maf_flag))
Here, maf_flag is a row field containing an array of bools. I want to get the counts of true/false aggregated over all rows.

Thank you for the help!

This seems like it’s coming from the Python Common Subexpression Elimination code. @patrick-schultz, any ideas?

My first guess would be that the CSE pass is catching a malformed IR. @zkoenig, would you mind trying to pare down your script to something as small as possible that still causes the error? Or even something not too small that you don’t mind sharing.

It took a minute but I was able to pare down the script to the following:

mt_paths = ['vcf/paths']

# Reading in and creating a list of all of the site matrix tables 
mt_list = [hl.import_vcf(mt_path,force_bgz = True) for mt_path in mt_paths]

# Annotating the matrix tables with variant QC data
mt_list = [hl.variant_qc(mt, name = 'variant_qc') for mt in mt_list]

# Joining the matrix tables using union_cols(); the result will then need to be annotated with row data
for i in range(len(mt_list)-1):
    if i == 0:
         mt = mt_list[i].union_cols(mt_list[i+1])
    else: 
        mt0 = mt.union_cols(mt_list[i+1])
        mt = mt0

mt = mt.annotate_rows(maf = hl.empty_array('float64'))

#Annotating maf list with mafs
for mt_next in mt_list:
    mt = mt.annotate_rows(maf = mt.maf.append(hl.min(mt_next.index_rows(mt.row_key).variant_qc.AF)))

mt = mt.annotate_rows(maf_flag = hl.empty_array('bool'))

# Filling maf flag list
for mt_next in mt_list:
    mt = mt.annotate_rows(maf_flag = mt.maf_flag.append(hl.min(mt_next.index_rows(mt.row_key).variant_qc.AF) <= 0.005))

counts = mt.aggregate_rows(hl.agg.array_agg(lambda x: hl.agg.counter(x), mt.maf_flag))

Here, mt_paths is a list of VCF paths. I didn’t show the actual paths to the data, but the script reads in 5 VCFs from gcloud and still produces the error.

The bug can be reproduced with this much smaller pipeline:

mt = hl.utils.range_matrix_table(10, 10)
mt = mt.annotate_rows(maf_flag = hl.empty_array('bool'))
counts = mt.aggregate_rows(hl.agg.array_agg(lambda x: hl.agg.counter(x), mt.maf_flag))