diff --git a/openproblems/tasks/label_projection/methods/__init__.py b/openproblems/tasks/label_projection/methods/__init__.py index 40d35a403c..7899498aea 100644 --- a/openproblems/tasks/label_projection/methods/__init__.py +++ b/openproblems/tasks/label_projection/methods/__init__.py @@ -10,6 +10,8 @@ from .scvi_tools import scanvi_hvg from .scvi_tools import scarches_scanvi_all_genes from .scvi_tools import scarches_scanvi_hvg +from .scvi_tools import scarches_scanvi_xgb_all_genes +from .scvi_tools import scarches_scanvi_xgb_hvg from .seurat import seurat from .xgboost import xgboost_log_cpm from .xgboost import xgboost_scran diff --git a/openproblems/tasks/label_projection/methods/scvi_tools.py b/openproblems/tasks/label_projection/methods/scvi_tools.py index 6212f8cd69..869dd85c78 100644 --- a/openproblems/tasks/label_projection/methods/scvi_tools.py +++ b/openproblems/tasks/label_projection/methods/scvi_tools.py @@ -1,5 +1,6 @@ from ....tools.decorators import method from ....tools.utils import check_version +from typing import Optional import functools @@ -86,7 +87,14 @@ def _scanvi(adata, test=False, n_hidden=None, n_latent=None, n_layers=None): return preds -def _scanvi_scarches(adata, test=False, n_hidden=None, n_latent=None, n_layers=None): +def _scanvi_scarches( + adata, + test=False, + n_hidden=None, + n_latent=None, + n_layers=None, + prediction_method="scanvi", +): import scvi if test: @@ -138,6 +146,15 @@ def _scanvi_scarches(adata, test=False, n_hidden=None, n_latent=None, n_layers=N train_kwargs["limit_val_batches"] = 10 query_model.train(plan_kwargs=dict(weight_decay=0.0), **train_kwargs) + if prediction_method == "scanvi": + preds = _pred_scanvi(adata, query_model) + elif prediction_method == "xgboost": + preds = _pred_xgb(adata, adata_train, adata_test, query_model, test=test) + + return preds + + +def _pred_scanvi(adata, query_model): # this is temporary and won't be used adata.obs["scanvi_labels"] = "Unknown" preds = query_model.predict(adata) @@ -146,6 +163,63 @@ def _scanvi_scarches(adata, test=False, n_hidden=None, n_latent=None, n_layers=N return preds +# note: could extend test option +def _pred_xgb( + adata, + adata_train, + adata_test, + query_model, + label_col="labels", + test=False, + num_round: Optional[int] = None, +): + import numpy as np + import xgboost as xgb + + df = _classif_df(adata_train, query_model, label_col) + + df["labels_int"] = df["labels"].cat.codes + categories = df["labels"].cat.categories + + # X_train = df.drop(columns="labels") + X_train = df.drop(columns=["labels", "labels_int"]) + # y_train = df["labels"].astype("category") + y_train = df["labels_int"].astype(int) + + X_test = query_model.get_latent_representation(adata_test) + + if test: + num_round = num_round or 2 + else: + num_round = num_round or 5 + + xgbc = xgb.XGBClassifier(tree_method="hist", objective="multi:softprob") + + xgbc.fit(X_train, y_train) + + # adata_test.obs["preds_test"] = xgbc.predict(X_test) + adata_test.obs["preds_test"] = categories[xgbc.predict(X_test)] + + preds = [ + adata_test.obs["preds_test"][idx] if idx in adata_test.obs_names else np.nan + for idx in adata.obs_names + ] + + return preds + + +def _classif_df(adata, trained_model, label_col): + import pandas as pd + + emb_data = trained_model.get_latent_representation(adata) + + df = pd.DataFrame(data=emb_data, index=adata.obs_names) + + df["labels"] = adata.obs[label_col] + + return df + + @_scanvi_method(method_name="scANVI (All genes)") def scanvi_all_genes(adata, test=False): adata.obs["labels_pred"] = _scanvi(adata, test=test) @@ -176,3 +250,25 @@ def scarches_scanvi_hvg(adata, test=False): adata.obs["labels_pred"] = _scanvi_scarches(bdata, test=test) adata.uns["method_code_version"] = check_version("scvi-tools") return adata + + +@_scanvi_scarches_method(method_name="scArches+scANVI+xgboost (All genes)") +def scarches_scanvi_xgb_all_genes(adata, test=False): + adata.obs["labels_pred"] = _scanvi_scarches( + adata, test=test, prediction_method="xgboost" + ) + + adata.uns["method_code_version"] = check_version("scvi-tools") + return adata + + +@_scanvi_scarches_method(method_name="scArches+scANVI+xgboost (Seurat v3 2000 HVG)") +def scarches_scanvi_xgb_hvg(adata, test=False): + hvg_df = _hvg(adata, test) + bdata = adata[:, hvg_df.highly_variable].copy() + adata.obs["labels_pred"] = _scanvi_scarches( + bdata, test=test, prediction_method="xgboost" + ) + + adata.uns["method_code_version"] = check_version("scvi-tools") + return adata