diff --git a/openproblems/tasks/label_projection/methods/__init__.py b/openproblems/tasks/label_projection/methods/__init__.py
index a5b099a5ae..1330b09e12 100644
--- a/openproblems/tasks/label_projection/methods/__init__.py
+++ b/openproblems/tasks/label_projection/methods/__init__.py
@@ -10,5 +10,7 @@
 from .scvi_tools import scanvi_hvg
 from .scvi_tools import scarches_scanvi_all_genes
 from .scvi_tools import scarches_scanvi_hvg
+from .scvi_tools import scarches_scanvi_xgb_all_genes
+from .scvi_tools import scarches_scanvi_xgb_hvg
 from .xgboost import xgboost_log_cpm
 from .xgboost import xgboost_scran
diff --git a/openproblems/tasks/label_projection/methods/scvi_tools.py b/openproblems/tasks/label_projection/methods/scvi_tools.py
index 05cd51825d..437c588019 100644
--- a/openproblems/tasks/label_projection/methods/scvi_tools.py
+++ b/openproblems/tasks/label_projection/methods/scvi_tools.py
@@ -86,7 +86,14 @@ def _scanvi(adata, test=False, n_hidden=None, n_latent=None, n_layers=None):
     return preds
 
 
-def _scanvi_scarches(adata, test=False, n_hidden=None, n_latent=None, n_layers=None):
+def _scanvi_scarches(
+    adata,
+    test=False,
+    n_hidden=None,
+    n_latent=None,
+    n_layers=None,
+    prediction_method="scanvi",
+):
     import scvi
 
     if test:
@@ -138,6 +145,23 @@ def _scanvi_scarches(adata, test=False, n_hidden=None, n_latent=None, n_layers=N
         train_kwargs["limit_val_batches"] = 10
     query_model.train(plan_kwargs=dict(weight_decay=0.0), **train_kwargs)
 
+    # Dispatch on the requested label-prediction backend.
+    if prediction_method == "scanvi":
+        preds = _pred_scanvi(adata, query_model)
+    elif prediction_method == "xgboost":
+        preds = _pred_xgb(adata, adata_train, adata_test, query_model, test=test)
+    else:
+        # Fail loudly instead of hitting an UnboundLocalError on `preds` below.
+        raise ValueError(
+            f"Unknown prediction_method {prediction_method!r}; "
+            'expected "scanvi" or "xgboost".'
+        )
+
+    return preds
+
+
+def _pred_scanvi(adata, query_model):
+    """Predict labels for ``adata`` with the query model's own classifier."""
     # this is temporary and won't be used
     adata.obs["scanvi_labels"] = "Unknown"
     preds = query_model.predict(adata)
@@ -146,6 +170,77 @@ def _scanvi_scarches(adata, test=False, n_hidden=None, n_latent=None, n_layers=N
     return preds
 
 
+def _pred_xgb(
+    adata,
+    adata_train,
+    adata_test,
+    query_model,
+    label_col="labels",
+    test=False,
+    num_round: Optional[int] = None,
+):
+    """Predict test-cell labels with XGBoost on the query model's latent space.
+
+    An ``XGBClassifier`` is fit on the latent representation of ``adata_train``
+    (targets taken from ``adata_train.obs[label_col]``) and applied to the
+    latent representation of ``adata_test``.
+
+    ``num_round`` is the number of boosting rounds; when unset it defaults to
+    2 in test mode and 5 otherwise.
+
+    Returns a list aligned with ``adata.obs_names``: the predicted label for
+    cells present in ``adata_test`` and NaN for every other cell.
+    """
+    import numpy as np
+    import xgboost as xgb
+
+    df = _classif_df(adata_train, query_model, label_col)
+
+    X_train = df.drop(columns="labels")
+    y_train = df["labels"]
+
+    X_test = query_model.get_latent_representation(adata_test)
+
+    if test:
+        num_round = num_round or 2
+    else:
+        num_round = num_round or 5
+
+    # n_estimators is the boosting-round count of the sklearn-style wrapper;
+    # without it the num_round chosen above would be silently ignored.
+    xgbc = xgb.XGBClassifier(
+        n_estimators=num_round, tree_method="hist", objective="multi:softprob"
+    )
+
+    xgbc.fit(X_train, y_train)
+
+    adata_test.obs["preds_test"] = xgbc.predict(X_test)
+
+    preds = [
+        adata_test.obs["preds_test"][idx] if idx in adata_test.obs_names else np.nan
+        for idx in adata.obs_names
+    ]
+
+    return preds
+
+
+def _classif_df(adata, trained_model, label_col):
+    """Build a DataFrame of latent embeddings plus a ``labels`` column.
+
+    Rows are indexed by ``adata.obs_names``; ``labels`` is copied from
+    ``adata.obs[label_col]``.
+    """
+    import pandas as pd
+
+    emb_data = trained_model.get_latent_representation(adata)
+
+    df = pd.DataFrame(data=emb_data, index=adata.obs_names)
+
+    df["labels"] = adata.obs[label_col]
+
+    return df
+
+
 @_scanvi_method(method_name="scANVI (All genes)")
 def scanvi_all_genes(adata, test=False):
     adata.obs["labels_pred"] = _scanvi(adata, test=test)
@@ -176,3 +271,27 @@ def scarches_scanvi_hvg(adata, test=False):
     adata.obs["labels_pred"] = _scanvi_scarches(bdata, test=test)
     adata.uns["method_code_version"] = check_version("scvi-tools")
     return adata
+
+
+@_scanvi_scarches_method(method_name="scArches+scANVI+xgboost (All genes)")
+def scarches_scanvi_xgb_all_genes(adata, test=False):
+    """Run ``_scanvi_scarches`` on all genes with XGBoost label prediction."""
+    adata.obs["labels_pred"] = _scanvi_scarches(
+        adata, test=test, prediction_method="xgboost"
+    )
+
+    adata.uns["method_code_version"] = check_version("scvi-tools")
+    return adata
+
+
+@_scanvi_scarches_method(method_name="scArches+scANVI+xgboost (Seurat v3 2000 HVG)")
+def scarches_scanvi_xgb_hvg(adata, test=False):
+    """Run ``_scanvi_scarches`` with XGBoost on the ``_hvg``-selected genes."""
+    hvg_df = _hvg(adata, test)
+    bdata = adata[:, hvg_df.highly_variable].copy()
+    adata.obs["labels_pred"] = _scanvi_scarches(
+        bdata, test=test, prediction_method="xgboost"
+    )
+
+    adata.uns["method_code_version"] = check_version("scvi-tools")
+    return adata