Optimizing processes to work with LD table

Hi Guys,

I’m trying to identify variants that are in LD with (and within a window around) my set of variants, using the gnomAD population-specific LD matrices available here:

# Block matrix
bmpath = f"gs://gcp-public-data--gnomad/release/{version}/ld/gnomad.genomes.r{version}.{pop}.common.adj.ld.bm"
# Index
indexpath = f"gs://gcp-public-data--gnomad/release/{version}/ld/gnomad.genomes.r{version}.{pop}.common.adj.ld.variant_indices.ht/"

I read the variant index as a Hail Table and the LD matrix as a Hail BlockMatrix, then take the relevant rectangular slice of it. My variants of interest are in a Hail Table, each with the relevant window around it (a field called region; not sure whether that’s required, though). I came up with the following process:
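
For completeness, reading the two objects looks roughly like this (a minimal sketch using the paths above; the variant index table carries an idx field giving each variant's position in the matrix, which is what I rely on below):

import hail as hl
from hail.linalg import BlockMatrix

# Variant index: one row per variant, with an idx field giving its
# row/column position in the block matrix.
ld_index = hl.read_table(indexpath)

# The LD matrix itself.
bm = BlockMatrix.read(bmpath)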

## 1. Join my variants with the LD variant index; this gives their indices in the LD matrix:
associations_ld_index = (
    top_loci.key_by("locus", "alleles")
    .join(ld_index.key_by("locus", "alleles"))
    .add_index("i")  # 0-based row index over my variants of interest
)

## 2. Join the windows with the LD variant index; this gives the indices of ALL variants that fall within the windows around my variants.
region_ld_index = (
    # interval join: keep index variants whose locus falls in any window
    ld_index.filter(hl.is_defined(top_loci.key_by("region")[ld_index.locus]))
    .add_index("j")  # 0-based column index over the window variants
    .persist()
)

## 3. For each variant in the region dataset generated above, get the start/stop indices of its window:
start, stop = hl.linalg.utils.locus_windows(
    region_ld_index.locus, ld_window, _localize=False # <= I'd like to use _localize=False, which I assume allows more performant execution
)

## 3b. Collect the start/stop arrays before slicing the block matrix <= is this step necessary?
start = start.collect()[0]
stop = stop.collect()[0]

# Collect the block-matrix indices (idx) of the association and region variants:
associations_idxs = associations_ld_index.idx.collect()
region_idxs = region_ld_index.idx.collect()

## 4. Get the positions (within the region list) of the region variants that are actually variants of interest:
associations_idx_set = set(associations_idxs)  # set for fast membership tests
associations_idx_inregion = [
    idx for idx, value in enumerate(region_idxs) if value in associations_idx_set
]

## 5. Slice the LD matrix.
# When slicing, we sparsify away the part of the LD matrix that lies outside the windows around our variants:
bm_sparse = bm.filter(associations_idxs, region_idxs).sparsify_row_intervals(
    [start[i] for i in associations_idx_inregion],  # window start (column index) for each retained row
    [stop[i] for i in associations_idx_inregion],   # window stop (column index) for each retained row
)

The subsequent part is clear: we get the variant pairs with .entries(), filter them by LD, and join with the region_ld_index and association_ld_index tables to recover the corresponding variants (roughly as in the sketch at the end of this post). However, I have serious concerns about the code pasted above (it does the job, though): it’s very confusing, and I don’t like that we need to collect() datasets, which might deteriorate performance (we might have 100k variants to filter from the 10M by 10M matrix). So my question: is there a more adequate way to implement this functionality? Is there a way to make it more performant?
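
For reference, the subsequent part I have in mind looks roughly like this (a minimal sketch: ld_cutoff is a placeholder threshold, and I'm assuming the i/j fields added in steps 1 and 2 line up with the rows/columns of the filtered matrix):

# Variant pairs and their LD values as a table with fields i, j, entry.
ld_pairs = bm_sparse.entries()

# Keep only pairs above the LD cutoff (placeholder value).
ld_cutoff = 0.8
ld_pairs = ld_pairs.filter(hl.abs(ld_pairs.entry) >= ld_cutoff)

# Map the row index i back to my variants of interest and the column
# index j back to the window variants.
ld_pairs = ld_pairs.annotate(
    lead=associations_ld_index.key_by("i")[ld_pairs.i],
    tag=region_ld_index.key_by("j")[ld_pairs.j],
)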