For now, you can also get this down to O(N * log2 N) by doing a tree union:
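mts_ = mts[:]  # mts: the list of MatrixTables to combine
iteration = 0
while len(mts_) > 1:
    iteration += 1
    print(f'iteration {iteration}')
    tmp = []
    for i in range(0, len(mts_), 2):
        pair = mts_[i:i+2]
        # union_cols is an instance method that takes one other MatrixTable;
        # an unpaired table at the end is carried into the next round as-is
        tmp.append(pair[0].union_cols(pair[1]) if len(pair) == 2 else pair[0])
    mts_ = tmp
[final_mt] = mts_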
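
If it helps, the same pairwise pattern can be pulled out into a small helper and checked without Hail. This is only a sketch: tree_reduce and combine are names made up here for illustration, and plain lists of column names stand in for MatrixTables.

```python
from typing import Callable, List, TypeVar

T = TypeVar("T")


def tree_reduce(items: List[T], combine: Callable[[T, T], T]) -> T:
    """Combine items pairwise, round by round, until one remains."""
    if not items:
        raise ValueError("tree_reduce needs at least one item")
    level = list(items)
    while len(level) > 1:
        nxt = []
        for i in range(0, len(level), 2):
            pair = level[i:i + 2]
            # An unpaired last item moves up to the next round unchanged.
            nxt.append(combine(pair[0], pair[1]) if len(pair) == 2 else pair[0])
        level = nxt
    return level[0]


# Stand-in data: each "table" is just a list of column (sample) names.
cols = [["s1"], ["s2"], ["s3"], ["s4"], ["s5"]]
merged = tree_reduce(cols, lambda a, b: a + b)
print(merged)
```

With Hail this would presumably be called as tree_reduce(mts, lambda a, b: a.union_cols(b)), assuming mts is a list of MatrixTables with compatible row keys and schemas. The loop runs about log2 N rounds for N inputs.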