I am attempting to use `hl.agg.array_agg()` together with an inner aggregator (`hl.agg.count_where()` / `hl.agg.counter()`), and I am running into the following error:
KeyError Traceback (most recent call last)
<ipython-input-5-ada4e7c97a13> in <module>
----> 1 counts = mt_test.aggregate_rows(hl.agg.array_agg(lambda x: hl.agg.counter(x), mt_test.maf_flag))
<decorator-gen-1163> in aggregate_rows(self, expr, _localize)
/opt/conda/default/lib/python3.6/site-packages/hail/typecheck/check.py in wrapper(__original_func, *args, **kwargs)
583 def wrapper(__original_func, *args, **kwargs):
584 args_, kwargs_ = check_all(__original_func, args, kwargs, checkers, is_method=is_method)
--> 585 return __original_func(*args_, **kwargs_)
586
587 return wrapper
/opt/conda/default/lib/python3.6/site-packages/hail/matrixtable.py in aggregate_rows(self, expr, _localize)
1987 agg_ir = TableAggregate(MatrixRowsTable(base._mir), subst_query)
1988 if _localize:
-> 1989 return Env.backend().execute(agg_ir)
1990 else:
1991 return construct_expr(agg_ir, expr.dtype)
/opt/conda/default/lib/python3.6/site-packages/hail/backend/backend.py in execute(self, ir, timed)
107
108 def execute(self, ir, timed=False):
--> 109 result = json.loads(Env.hc()._jhc.backend().executeJSON(self._to_java_ir(ir)))
110 value = ir.typ._from_json(result['value'])
111 timings = result['timings']
/opt/conda/default/lib/python3.6/site-packages/hail/backend/backend.py in _to_java_ir(self, ir)
103 r = CSERenderer(stop_at_jir=True)
104 # FIXME parse should be static
--> 105 ir._jir = ir.parse(r(ir), ir_map=r.jirs)
106 return ir._jir
107
/opt/conda/default/lib/python3.6/site-packages/hail/ir/renderer.py in __call__(self, root)
181
182 def __call__(self, root: 'ir.BaseIR') -> str:
--> 183 binding_sites = CSEAnalysisPass(self)(root)
184 return CSEPrintPass(self)(root, binding_sites)
185
/opt/conda/default/lib/python3.6/site-packages/hail/ir/renderer.py in __call__(self, root)
251
252 if isinstance(child, ir.IR):
--> 253 bind_depth = child_frame.bind_depth()
254 lets = None
255 if bind_depth < len(stack):
/opt/conda/default/lib/python3.6/site-packages/hail/ir/renderer.py in bind_depth(self)
347 bind_depth = max(bind_depth, max(self.context[0][var] for var in self.node.free_vars))
348 if len(self.node.free_agg_vars) > 0:
--> 349 bind_depth = max(bind_depth, max(self.context[1][var] for var in self.node.free_agg_vars))
350 if len(self.node.free_scan_vars) > 0:
351 bind_depth = max(bind_depth, max(self.context[2][var] for var in self.node.free_scan_vars))
/opt/conda/default/lib/python3.6/site-packages/hail/ir/renderer.py in <genexpr>(.0)
347 bind_depth = max(bind_depth, max(self.context[0][var] for var in self.node.free_vars))
348 if len(self.node.free_agg_vars) > 0:
--> 349 bind_depth = max(bind_depth, max(self.context[1][var] for var in self.node.free_agg_vars))
350 if len(self.node.free_scan_vars) > 0:
351 bind_depth = max(bind_depth, max(self.context[2][var] for var in self.node.free_scan_vars))
KeyError: 'va'
Here is the line of code that causes this error:
counts = mt_test.aggregate_rows(hl.agg.array_agg(lambda x: hl.agg.counter(x), mt_test.maf_flag))
Here, `maf_flag` is a row field containing an array of bools. For each position in that array, I want the counts of True/False aggregated over all rows.
Thank you for the help!