Annotate_cols out-of-memory issues

Hi hail team,

I’m running into this error relatively frequently in my notebooks today:

ERROR (root 1003): Exception while sending command.
Traceback (most recent call last):
  File "/usr/lib/spark/python/lib/py4j-0.10.7-src.zip/py4j/java_gateway.py", line 1159, in send_command
    raise Py4JNetworkError("Answer from Java side is empty")
py4j.protocol.Py4JNetworkError: Answer from Java side is empty

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/usr/lib/spark/python/lib/py4j-0.10.7-src.zip/py4j/java_gateway.py", line 985, in send_command
    response = connection.send_command(command)
  File "/usr/lib/spark/python/lib/py4j-0.10.7-src.zip/py4j/java_gateway.py", line 1164, in send_command
    "Error while receiving", e, proto.ERROR_ON_RECEIVE)
py4j.protocol.Py4JNetworkError: Error while receiving

and the rest of the error:

---------------------------------------------------------------------------
Py4JError                                 Traceback (most recent call last)
<ipython-input-15-4533291abd08> in <module>
      2 plots = []
      3 for b in batches:
----> 4     p = hl.plot.histogram(chr20_agg[f'n_not_called_{b}'], title=f'n_not_called {b}')
      5     plots.append(Panel(child=p, title=f'{b}'))
      6 show(Tabs(tabs=plots))

</opt/conda/default/lib/python3.6/site-packages/decorator.py:decorator-gen-1543> in histogram(data, range, bins, legend, title, log, interactive)

/opt/conda/default/lib/python3.6/site-packages/hail/typecheck/check.py in wrapper(__original_func, *args, **kwargs)
    583     def wrapper(__original_func, *args, **kwargs):
    584         args_, kwargs_ = check_all(__original_func, args, kwargs, checkers, is_method=is_method)
--> 585         return __original_func(*args_, **kwargs_)
    586 
    587     return wrapper

/opt/conda/default/lib/python3.6/site-packages/hail/plot/plots.py in histogram(data, range, bins, legend, title, log, interactive)
    390                 finite_data = hail.bind(lambda x: hail.case().when(hail.is_finite(x), x).or_missing(), data)
    391                 start, end = agg_f((aggregators.min(finite_data),
--> 392                                     aggregators.max(finite_data)))
    393                 if start is None and end is None:
    394                     raise ValueError(f"'data' contains no values that are defined and finite")

</opt/conda/default/lib/python3.6/site-packages/decorator.py:decorator-gen-1155> in aggregate_rows(self, expr, _localize)

/opt/conda/default/lib/python3.6/site-packages/hail/typecheck/check.py in wrapper(__original_func, *args, **kwargs)
    583     def wrapper(__original_func, *args, **kwargs):
    584         args_, kwargs_ = check_all(__original_func, args, kwargs, checkers, is_method=is_method)
--> 585         return __original_func(*args_, **kwargs_)
    586 
    587     return wrapper

/opt/conda/default/lib/python3.6/site-packages/hail/matrixtable.py in aggregate_rows(self, expr, _localize)
   1987         agg_ir = TableAggregate(MatrixRowsTable(base._mir), subst_query)
   1988         if _localize:
-> 1989             return Env.backend().execute(agg_ir)
   1990         else:
   1991             return construct_expr(agg_ir, expr.dtype)

/opt/conda/default/lib/python3.6/site-packages/hail/backend/backend.py in execute(self, ir, timed)
    107 
    108     def execute(self, ir, timed=False):
--> 109         result = json.loads(Env.hc()._jhc.backend().executeJSON(self._to_java_ir(ir)))
    110         value = ir.typ._from_json(result['value'])
    111         timings = result['timings']

/usr/lib/spark/python/lib/py4j-0.10.7-src.zip/py4j/java_gateway.py in __call__(self, *args)
   1255         answer = self.gateway_client.send_command(command)
   1256         return_value = get_return_value(
-> 1257             answer, self.gateway_client, self.target_id, self.name)
   1258 
   1259         for temp_arg in temp_args:

/opt/conda/default/lib/python3.6/site-packages/hail/utils/java.py in deco(*args, **kwargs)
    211         import pyspark
    212         try:
--> 213             return f(*args, **kwargs)
    214         except py4j.protocol.Py4JJavaError as e:
    215             s = e.java_exception.toString()

/usr/lib/spark/python/lib/py4j-0.10.7-src.zip/py4j/protocol.py in get_return_value(answer, gateway_client, target_id, name)
    334             raise Py4JError(
    335                 "An error occurred while calling {0}{1}{2}".
--> 336                 format(target_id, ".", name))
    337     else:
    338         type = answer[1]

Py4JError: An error occurred while calling o141.executeJSON

Most recently, I ran into this error trying to create some plots (hl.plot.histogram). Any ideas what’s going on, or how I can fix this?

The real error message is in the Hail log. Can you attach that here and we’ll take a look?
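
If you’re not sure where that is: Hail prints the log path when it initializes, and you can also set it explicitly. A minimal sketch (the path here is just an example):

import hail as hl
hl.init(log='/home/user/hail.log')  # detailed Java-side errors land in this file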

java_error.log (1.5 MB)

You appear to have run out of memory (at least 32 GB in use). What pipeline triggered this? Something with a lot of column aggregations?

Ah. It’s the same code as in Strange keyerror (KeyError: 'g'), swapping in

chr20_agg = chr20_agg.annotate_rows(
    n_not_called_50K=hl.agg.filter(chr20_agg.batch == '50K',
        hl.agg.take(chr20_agg.n_not_called, 1)),
    n_not_called_100K=hl.agg.filter(chr20_agg.batch == '100K', 
        hl.agg.take(chr20_agg.n_not_called, 1)),
    n_not_called_200K=hl.agg.filter(chr20_agg.batch == '200K', 
        hl.agg.take(chr20_agg.n_not_called, 1))
)

and

chr20_agg = chr20_agg.rows()
chr20_agg = chr20_agg.explode(chr20_agg.n_not_called_50K)
chr20_agg = chr20_agg.explode(chr20_agg.n_not_called_100K)
chr20_agg = chr20_agg.explode(chr20_agg.n_not_called_200K)
chr20_agg.describe()

from bokeh.io import show
from bokeh.models import Panel, Tabs  # Panel/Tabs and show come from bokeh

batches = ['50K', '100K', '200K']
plots = []
for b in batches:
    p = hl.plot.histogram(chr20_agg[f'n_not_called_{b}'], title=f'n_not_called {b}')
    plots.append(Panel(child=p, title=f'{b}'))
show(Tabs(tabs=plots))

I don’t totally follow, but from the other post:

chr20 = chr20.annotate_rows(
    n_not_called_50K=chr20.aggregate_cols(
        hl.agg.filter(chr20.batch == '50K',
        hl.agg.count_where(hl.is_missing(chr20.GT)))),
    n_not_called_100K=chr20.aggregate_cols(
        hl.agg.filter(chr20.batch == '100K', 
        hl.agg.count_where(hl.is_missing(chr20.GT)))),
    n_not_called_200K=chr20.aggregate_cols(
        hl.agg.filter(chr20.batch == '200K', 
        hl.agg.count_where(hl.is_missing(chr20.GT))))
)

I wouldn’t expect this to run? The following aggregates over column fields to produce a Python value:

chr20.aggregate_cols(
        hl.agg.filter(chr20.batch == '200K', 
        hl.agg.count_where(hl.is_missing(chr20.GT))))

But GT is an entry field, not a column field. That explains your KeyError, I guess, but that issue ought to be caught earlier.
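
For contrast, aggregating an actual column field does work and returns a Python value. A minimal sketch using hl.agg.counter over the batch column field:

# Valid: 'batch' is a column field, so aggregate_cols can reach it.
batch_counts = chr20.aggregate_cols(hl.agg.counter(chr20.batch))
print(batch_counts)  # e.g. {'50K': ..., '100K': ..., '200K': ...}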

I wouldn’t expect the code you’ve posted here to cause a memory overflow on the leader node (what I believe you’re encountering).
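
If it does turn out to be driver-side memory, one generic knob (a sketch for a hand-configured Spark setup, not specific to your cluster; on Dataproc you’d size up the master machine instead) is to give the driver more memory before Hail starts the JVM:

import os
# Must run before hl.init() launches the JVM; 8g is an arbitrary example.
os.environ['PYSPARK_SUBMIT_ARGS'] = '--driver-memory 8g pyspark-shell'
import hail as hl
hl.init()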

Anyway, I feel like you’re looking for this?

chr20 = chr20.annotate_rows(
    n_not_called_50K=
        hl.agg.filter(chr20.batch == '50K',
        hl.agg.count_where(hl.is_missing(chr20.GT))),
    n_not_called_100K=
        hl.agg.filter(chr20.batch == '100K', 
        hl.agg.count_where(hl.is_missing(chr20.GT))),
    n_not_called_200K=
        hl.agg.filter(chr20.batch == '200K', 
        hl.agg.count_where(hl.is_missing(chr20.GT)))
)

(removing the aggregate_cols)

Ooh, I see. And that is what I was trying to do, thanks! Unfortunately, running

chr20 = chr20.annotate_rows(
    n_not_called_50K=
        hl.agg.filter(chr20.batch == '50K',
        hl.agg.count_where(hl.is_missing(chr20.GT))),
    n_not_called_100K=
        hl.agg.filter(chr20.batch == '100K', 
        hl.agg.count_where(hl.is_missing(chr20.GT))),
    n_not_called_200K=
        hl.agg.filter(chr20.batch == '200K', 
        hl.agg.count_where(hl.is_missing(chr20.GT)))
)

and

chr20 = chr20.rows()
chr20.show(5)

gave me the same Java error. Is the out-of-memory issue on the master node or the worker nodes?

Can you share the log for that failure? Probably something still too big. Best guess is something related to too much column data.

java_error.log (1.6 MB)

I tried running the same thing in a submit script, too: test.log (1.6 MB)

How many bytes are these two tables:

sample_map_ht = hl.read_table(array_sample_map_ht(data_source, freeze))
sample_map = hl.import_table(array_sample_map(freeze), delimiter=',', quote='"')

? That’s the only thing I have left to guess at. It looks like they must be many tens of GB in Hail’s internal representation, which either means they’re huge or there’s a Hail issue.
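
If it helps, one way to measure is to sum file sizes with hl.hadoop_ls (a sketch; du here is a hypothetical helper, and the paths are the ones from your script):

import hail as hl

def du(path):
    # Recursively total file sizes in bytes under a path (works for .ht directories).
    total = 0
    for entry in hl.hadoop_ls(path):
        total += du(entry['path']) if entry['is_dir'] else entry['size_bytes']
    return total

print(du(array_sample_map_ht(data_source, freeze)) / 1024 ** 2, 'MB')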

sample_map_ht is 5.63 MB, and sample_map is 4.09 MB

OK, I think this will have to wait until tomorrow when someone from the compiler team can look at it in detail. cc: @tpoterba

I vaguely remember somebody reporting similar unexpected OOMs with column annotation a month or two ago on Zulip. First order of business is replicating this locally.

Am I right that the following triggers the OOM?

hardcalls = get_ukbb_data(data_source, freeze, raw=False, split=True, adj=False)

sample_map_ht = hl.read_table(array_sample_map_ht(data_source, freeze))
sample_map = hl.import_table(array_sample_map(freeze), delimiter=',', quote='"')
sample_map = sample_map.key_by(s=sample_map.eid_26041)

print(hardcalls.count())
chr20 = hl.filter_intervals(hardcalls, [hl.parse_locus_interval('chr20', reference_genome='GRCh38')])
print(chr20.count())

chr20 = chr20.select_rows('a_index', 'was_split')

chr20 = chr20.annotate_cols(**sample_map_ht[chr20.s])
chr20 = chr20.annotate_cols(**sample_map[chr20.ukbb_app_26041_id])

chr20 = chr20.select_cols('batch', 'batch.c')
chr20 = chr20.transmute_cols(batch_num=chr20['batch'],
                             batch=chr20['batch.c'])
chr20.cols()._force_count()

Yes, I just ran this on 0.2.29, and it crashed with the same error as above.

How about this:

hardcalls = get_ukbb_data(data_source, freeze, raw=False, split=True, adj=False)

sample_map_ht = hl.read_table(array_sample_map_ht(data_source, freeze))
sample_map = hl.import_table(array_sample_map(freeze), delimiter=',', quote='"')
sample_map = sample_map.key_by(s=sample_map.eid_26041)

chr20 = hl.filter_intervals(hardcalls, [hl.parse_locus_interval('chr20', reference_genome='GRCh38')])
chr20 = chr20.annotate_cols(**sample_map_ht[chr20.s])
chr20 = chr20.annotate_cols(**sample_map[chr20.ukbb_app_26041_id])

chr20 = chr20.select_cols('batch', 'batch.c')
chr20 = chr20.transmute_cols(batch_num=chr20['batch'],
                             batch=chr20['batch.c'])
chr20.cols()._force_count()

I think I also need to see what get_ukbb_data is doing.

def get_ukbb_data(data_source: str, freeze: int = CURRENT_FREEZE, adj: bool = False, split: bool = True,
                  raw: bool = False, non_refs_only: bool = False, meta_root: Optional[str] = None) -> hl.MatrixTable:
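    """Read the UKBB MatrixTable for a data source and freeze; optionally
    filter to adj genotypes and annotate sample metadata under meta_root."""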
    from gnomad_hail.utils import filter_to_adj

    if raw and split:
        raise DataException('No split raw data. Use of hardcalls is recommended.')

    if non_refs_only:
        mt = hl.read_matrix_table(get_ukbb_data_path(data_source, freeze, split=split, non_refs_only=non_refs_only))
    else:
        mt = hl.read_matrix_table(get_ukbb_data_path(data_source, freeze, hardcalls=not raw, split=split))

    if adj:
        mt = filter_to_adj(mt)

    if meta_root:
        meta_ht = hl.read_table(meta_ht_path(data_source, freeze))
        mt = mt.annotate_cols(**{meta_root: meta_ht[mt.s]})

    return mt

OK, cool, so this is just a read_matrix_table.

I’ll see if I can replicate.