Using Spark ML to create and apply Random Forests

The code below was used to build the random forest model that filters gnomAD variants.
A dummy example of the intended usage flow is implemented in run_rf_test(vds, output).


import logging
from hail.expr import *
from pyspark.ml.feature import *
from pyspark.ml.classification import *
from pyspark.ml import *
from pyspark.sql import Row
from pyspark.sql import SparkSession
import pyspark
from pyspark import SparkContext

logging.basicConfig(format="%(levelname)s (%(name)s %(lineno)s): %(message)s")
logger = logging.getLogger("RF")
logger.setLevel(logging.INFO)

def run_rf_test(vds, output):
    vds = vds.annotate_variants_expr('va.train = pcoin(0.9), va.feature1 = pcoin(0.1), va.feature2 = rnorm(0.0, 1.0)')
    vds = vds.annotate_variants_expr('va.label = if(va.feature1 && va.feature2>0) "TP" else "FP"')
    rf_features = ['va.feature1','va.feature2']

    rf_model = train_rf(vds, rf_features)
    save_model(rf_model, out=output, overwrite=True)
    rf_model = load_model(output)
    return apply_rf_model(vds, rf_model, rf_features)


# Replaces `.` with `_`, since Spark ML doesn't support column names containing `.`
# (e.g., toSSQL('va.info.QD') -> 'va_info_QD')
def toSSQL(s):
    return s.replace('.', '_')


def vds_to_rf_df(vds, rf_features, label='va.label'):

    cols = rf_features + [label]
    kt = vds.split_multi().variants_keytable()

    # Rename everything to avoid problem with dot-delimited paths
    kt = kt.annotate(['%s = %s' % (toSSQL(x), x) for x in cols] +
                     ['variant = str(v)'])
    kt = kt.select([toSSQL(x) for x in cols] + ['variant'])

    # Create dataframe
    # 1) drop rows with missing values (not supported for RF)
    # 2) replace missing labels with standard value since StringIndexer doesn't handle missing values
    df = kt.to_dataframe().dropna(subset=[toSSQL(x) for x in rf_features]).fillna('NA', subset=[toSSQL(label)])

    return df
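
# Illustrative usage (field names as in run_rf_test above):
#   df = vds_to_rf_df(vds, ['va.feature1', 'va.feature2'])
#   df.columns  # ['va_feature1', 'va_feature2', 'va_label', 'variant']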


def get_features_importance(rf_model, rf_index=-2, assembler_index=-3):
    feature_names = [x[:-len("_indexed")] if x.endswith("_indexed") else x for x in rf_model.stages[assembler_index].getInputCols()]
    feature_importance = {toSSQL(new_name): importance for
                          (new_name, importance) in zip(feature_names, rf_model.stages[rf_index].featureImportances)}
    return feature_importance


def get_labels(rf_model):
    return rf_model.stages[0].labels


def apply_rf_model(vds, rf_model, rf_features, root='va.rf', label='va.label'):
    logger.info("Applying RF model to VDS")

    df = vds_to_rf_df(vds, rf_features, label=label)

    feature_importance = get_features_importance(rf_model)

    transformed = rf_model.transform(df)

    logger.info("Annotating dataset with results")

    # Required for RDD.toDF() !
    spark = SparkSession(vds.hc.sc)

    kt = vds.hc.dataframe_to_keytable(
        transformed.rdd.map(
            lambda row:
            Row(variant=row['variant'],
                probability=row["probability"].toArray().tolist(),
                prediction=row["predictedLabel"])
        ).toDF()
    ).persist()

    probability_to_dict_expr = 'probability = index([{%s}], label).mapValues(x => x.prob)' % "},{".join(
        ['label: "%s", prob: probability[%d]' % (l, i) for (i, l) in enumerate(get_labels(rf_model))])

    kt = kt.annotate(['variant = Variant(variant)',
                      probability_to_dict_expr]).key_by('variant')

    vds = vds.annotate_variants_keytable(
        kt, "%s.prediction = table.prediction, %s.probability = table.probability" % (root, root))
    # root[3:] strips the leading 'va.' to build the global annotation path (e.g. 'global.rf')
    vds = vds.annotate_global_py('global.%s' % (root[3:]), feature_importance, TDict(TString(), TDouble()))

    return vds
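
# End-to-end usage sketch (hypothetical fields, matching run_rf_test above):
#   features = ['va.feature1', 'va.feature2']
#   model = train_rf(vds, features)
#   vds = apply_rf_model(vds, model, features)
#   # -> adds va.rf.prediction, va.rf.probability, and global.rf (feature importances)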


def save_model(rf_model, out, overwrite=False):
    logger.info("Saving model to %s" % out)
    if overwrite:
        rf_model.write().overwrite().save(out)
    else:
        rf_model.save(out)


def load_model(path):
    logger.info("Loading model from %s" % path)
    return pyspark.ml.PipelineModel.load(path)


def train_rf(vds, rf_features, training='va.train', label='va.label', num_trees=500, max_depth=5):
    logger.info("Training RF model using:\n"
                "features: %s\n"
                "training: %s\n"
                "labels: %s\n"
                "num_trees: %d\n"
                "max_depth: %d" %( ",".join(rf_features), training, label, num_trees, max_depth ))

    df = vds_to_rf_df(vds, rf_features + [training], label=label)
    df = df.drop('variant')

    SSQL_training = toSSQL(training)
    SSQL_label = toSSQL(label)

    label_indexer = StringIndexer(inputCol=SSQL_label, outputCol=SSQL_label + "_indexed").fit(df)
    labels = label_indexer.labels
    logger.info("Found labels: %s" % labels)

    string_features = [x[0] for x in df.dtypes if x[0] != SSQL_label and x[0] != SSQL_training and x[1] == 'string']
    if string_features:
        logger.info("Indexing string features: %s", ",".join(string_features))
    string_features_indexers = [StringIndexer(inputCol=x, outputCol=x + "_indexed").fit(df)
                                for x in string_features]

    assembler = VectorAssembler(inputCols=[x[0] + "_indexed" if x[1] == 'string' else x[0]
                                           for x in df.dtypes if x[0] != SSQL_label and x[0] != SSQL_training],
                                outputCol="features")

    rf = RandomForestClassifier(labelCol=SSQL_label + "_indexed", featuresCol="features",
                                maxDepth=max_depth, numTrees=num_trees)

    label_converter = IndexToString(inputCol='prediction', outputCol='predictedLabel', labels=labels)

    pipeline = Pipeline(stages=[label_indexer] + string_features_indexers +
                               [assembler, rf, label_converter])
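
    # Note: get_features_importance's defaults (rf_index=-2, assembler_index=-3)
    # rely on [assembler, rf, label_converter] being the last three stages,
    # and get_labels relies on the label indexer being stage 0.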

    # Train model on training sites
    logger.info("Training RF model")
    training_df = df.filter(SSQL_training).drop(SSQL_training)
    rf_model = pipeline.fit(training_df)

    feature_importance = get_features_importance(rf_model)

    logger.info("RF features importance:\n%s" % "\n".join(["%s: %s" % (f,i) for (f,i) in feature_importance.iteritems()]))

    return rf_model
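
Once apply_rf_model has annotated the dataset, the predictions can be used to
filter variants, for example (the 0.9 cutoff is purely illustrative):

vds.filter_variants_expr('va.rf.prediction == "TP" || va.rf.probability["TP"] > 0.9')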

We would like to use the same (or a similar) procedure to flag likely
false-positive calls with a random forest classifier, as was done in gnomAD.

We tried to run the example code from May 2017, but it failed with the error
'HailContext' object has no attribute 'dataframe_to_keytable',
and with a similar error for vds.annotate_variants_keytable.

Would you provide an updated version of the example, or perhaps an even
more detailed one, for Hail 0.2?

Tim

The code Laurent posted was for a pre-0.1 version, and it’ll have to change a lot to work with 0.2. I can take a stab at a rough conversion sometime in the next week.
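
In the meantime, here is a very rough, untested sketch of the main translation
points, assuming Hail 0.2's Table.to_spark / Table.from_spark round-trip and
keeping the Spark ML pipeline above unchanged (all field names are illustrative):

import hail as hl

def mt_to_rf_df(mt, features, label='label'):
    # In 0.2 the row fields live on a Table; top-level field names contain no
    # dots, so the toSSQL renaming above is no longer needed.
    ht = mt.rows().select(*(features + [label]))
    # Stringify the variant key so it survives the round-trip through Spark.
    ht = ht.annotate(variant=hl.str(ht.locus) + ':' + hl.delimit(ht.alleles, ':'))
    ht = ht.key_by('variant').drop('locus', 'alleles')
    return ht.to_spark()  # replaces variants_keytable().to_dataframe()

def annotate_with_rf(mt, transformed_df):
    # hl.Table.from_spark replaces hc.dataframe_to_keytable; the 'probability'
    # vector column would need converting to an array first (e.g. via an RDD
    # map, as in the 0.1 code above), so only the label is carried back here.
    ht = hl.Table.from_spark(
        transformed_df.select('variant', 'predictedLabel'), key=['variant'])
    # annotate_rows replaces annotate_variants_keytable
    return mt.annotate_rows(
        rf=ht[hl.str(mt.locus) + ':' + hl.delimit(mt.alleles, ':')])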