42 commits
20c782b
making paths anonymous
hagerrady13 Jun 25, 2025
d95e971
improvements/renames
hagerrady13 Jun 30, 2025
9e3617d
add figure 5 notebook
melisandeteng Jun 29, 2025
1d83987
remove un-used code
hagerrady13 Jun 30, 2025
345d618
adding data preparation code
hagerrady13 Jun 30, 2025
f90c1c6
clean run_files
hagerrady13 Jun 30, 2025
07deadd
add requirements file
hagerrady13 Jun 30, 2025
fe9716c
update .gitignore
hagerrady13 Jun 30, 2025
4f06f41
update .gitignore
hagerrady13 Jun 30, 2025
00cae14
Merge remote-tracking branch 'origin/release' into release
hagerrady13 Jun 30, 2025
fbb9014
adding README
hagerrady13 Jun 30, 2025
3d20cba
Update README.md
zbirobin Jul 1, 2025
dc19d9f
Upload figures 1 and 2.
zbirobin Jul 1, 2025
c2b8156
Add figures 1 and 2 in png.
zbirobin Jul 1, 2025
0cbf242
Add Figure 1 to the README.
zbirobin Jul 1, 2025
5e25835
updating README - reproducing results
hagerrady13 Jul 1, 2025
51bdf6c
initial
hagerrady13 Nov 21, 2025
f9ef1b2
add random forest baseline
hagerrady13 Nov 26, 2025
201d6b0
update fig 5 notebook
melisandeteng Dec 1, 2025
6a4b288
update fig 5 notebook
melisandeteng Dec 1, 2025
fc969b8
tutorial added
hagerrady13 Dec 2, 2025
86cd691
multi runs of RF on SatBird
hagerrady13 Dec 3, 2025
09e687c
cleaning outputs
hagerrady13 Dec 3, 2025
d3da001
fixes
hagerrady13 Dec 3, 2025
c57e9b5
Merge pull request #26 from RolnickLab/tutorial
hagerrady13 Dec 3, 2025
aa5f49d
add sjsdm
melisandeteng Dec 8, 2025
c999f3e
Update README.md
melisandeteng Dec 8, 2025
5dd58ac
add data prep
melisandeteng Dec 8, 2025
38bc941
Merge branch 'sjsdm' of https://github.com/RolnickLab/SDMPartialLabel…
melisandeteng Dec 8, 2025
180ce3d
Update README.md
melisandeteng Dec 8, 2025
0179105
Update sjsdm-splot.py
melisandeteng Dec 8, 2025
621a59f
Merge pull request #27 from RolnickLab/sjsdm
hagerrady13 Dec 9, 2025
24028d3
updated scripts for random forest
hagerrady13 Dec 9, 2025
140f9a7
removing paths
hagerrady13 Dec 9, 2025
b3090b5
removing paths
hagerrady13 Dec 9, 2025
09c5932
cleaning
hagerrady13 Dec 9, 2025
e07564a
Merge pull request #28 from RolnickLab/random_forest
hagerrady13 Dec 9, 2025
edb6f0b
Update README with additional baselines information
hagerrady13 Dec 9, 2025
8caf375
Figure 3 with the histogram of difference.
zbirobin Dec 23, 2025
f7fb7d9
Add Figure 6 varying size of training data.
zbirobin Dec 23, 2025
c3dba0a
Add figure 10, plotting results of Table 1.
zbirobin Dec 23, 2025
a062a5d
Merge pull request #29 from RolnickLab/robin/figures
hagerrady13 Dec 24, 2025
11 changes: 4 additions & 7 deletions .gitignore
@@ -1,11 +1,8 @@
*.ipynb_checkpoints
*.pyc
*.DS_Store
__pycache__/
1337/epoch=11-step=623.ckpt
1337/*
4017/*
outputs/*
2676/*
job_*
checkpoints*
data/
model_checkpoints/
data/
experiments_outputs/
197 changes: 164 additions & 33 deletions README.md
@@ -1,50 +1,181 @@
# SDM with Partial Labels
# CISO-SDM 🐦🦋🌿

### Evaluation:
We use the parameter `predict_family_of_species` to control which family subset of species we are evaluating
### Species within a single taxonomy setup:
SatBird:
- `predict_family_of_species = 0` : evaluate non-songbirds
- `predict_family_of_species = 1` : evaluate songbirds
This repository contains the code to reproduce the results from the paper:

splot:
- `predict_family_of_species = 0` : evaluate non-trees
- `predict_family_of_species = 1` : evaluate trees
### CISO: Species Distribution Modeling Conditioned on Incomplete Species Observations

### Species in Multi-taxa setup:
SatBird & SatButterfly:
- To evaluate with **no partial labels** given (everything is unknown), set `eval_known_rate == 0 `
- `predict_family_of_species = 0` : evaluate birds
- `predict_family_of_species = 1` : evaluate butterflies

- To evaluate with **partial labels** given (some labels known), set `eval_known_rate == 1 `
- `predict_family_of_species = 0` : evaluate birds
- `predict_family_of_species = 1` : evaluate butterflies
<br>

<div align="left">
<img src="figures/overview_dataset.png" alt="Figure" width="780"/>
</div>

### Running code:
<br>

#### Installation
Code runs on Python 3.10. You can create conda env using `requirements/environment.yaml` or install pip packages from `requirements/requirements.txt`
## 🛠️ Installation

We recommend following these steps for installing the required packages:
### ⚙️ Requirements

```conda env create -f requirements/environment.yaml```
This project requires **Python 3.11**. All dependencies are listed in `requirements/requirements.txt`.

```conda activate satbird```
We recommend using **Conda** to create and manage a clean environment:

```conda install pytorch==2.0.1 torchvision==0.15.2 torchaudio==2.0.2 pytorch-cuda=11.8 -c pytorch -c nvidia```
```bash
conda create -n ciso_env python=3.11
conda activate ciso_env
pip install -r requirements/requirements.txt
```

#### Training and testing
#### 💡 Not using Conda?

* To train the model (check `run_files/job.sh`) : `python train.py args.config=configs/base.yaml`. Examples of all config files for different baselines
are available in `configs`.
* To train a model: `python train.py args.config=$CONFIG_FILE_NAME `
* To test a model: `python test.py args.config=$CONFIG_FILE_NAME `
You can also use a standard Python virtual environment:

To log experiments on comet-ml, make sure you have exported your COMET_API_KEY and COMET_WORKSPACE in your environmental variables.
You can do so with `export COMET_API_KEY=your_comet_api_key` in your terminal.
```bash
python3.11 -m venv ciso_env
source ciso_env/bin/activate # On Windows use `ciso_env\Scripts\activate`
pip install -r requirements/requirements.txt
```

### 🖥️ Hardware Support

The results can be reproduced on any device (GPU, eGPU, or CPU), though the computational time will vary depending on the hardware's parallel processing capabilities. If you encounter memory issues, especially on lower-end devices, consider reducing the batch size in the training configuration to mitigate them. You may also need to install or update the appropriate NVIDIA drivers to work with PyTorch, depending on your specific setup.

## 📂 Datasets

All datasets used in the paper are publicly available [here](https://huggingface.co/cisosdm/datasets) on Hugging Face.

#### Data preparation scripts

Data preparation code is located in the `data_preprocessing/` folder:

* SatButterfly: `ebutterfly_data_preparation.ipynb`
* sPlotOpen: `prepare_sPlotOpen_data.ipynb`
* SatBird × sPlotOpen (co-located data): `prepare_satbirdxsplots.ipynb`

## 🔬 Experiment configurations:

The `configs` directory contains subfolders for each dataset setup:

* `configs/satbird`
* `configs/satbirdxsatbutterfly`
* `configs/satbirdxsplot`
* `configs/splot`

Each subfolder includes YAML config files for models reported in the paper. These configs are used for both training and evaluation.

| File | Model Description |
|----------------------------|-----------------------------------|
| `config_ciso.yaml` | CISO model |
| `config_linear.yaml` | Linear model |
| `config_maxent.yaml` | Linear model with MaxEnt features |
| `config_mlp.yaml` | MLP |
| `config_mlp_plusplus.yaml` | MLP++ |

## 🤖 Trained model checkpoints:

All trained model checkpoints are available [here](https://huggingface.co/cisosdm/model_checkpoints).
For each dataset, each folder includes 3 sub-folders corresponding to 3 different runs (seeds).

| Folder | Description |
|--------------------------|----------------------------------------------------------|
| `1_sPlotOpen` | Within-dataset experiments for sPlotOpen |
| `2_SatBird` | Within-dataset experiments for SatBird |
| `3_SatBirdxSatButterfly` | Across-datasets experiments for SatBird and SatButterfly |
| `4_SatBirdxsPlotOpen` | Across-datasets experiments for SatBird and sPlotOpen |
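
As a quick check after downloading, the run sub-folders of one experiment can be listed with a few lines of Python. This is a minimal sketch and not part of the repository: the helper name `list_run_dirs` is hypothetical, and it only assumes that each run (seed) is stored as a sub-folder of the experiment folder, as described above.

```python
from pathlib import Path


def list_run_dirs(experiment_dir):
    """Return the sub-folders of an experiment folder, sorted by name.

    Hypothetical helper: it assumes one sub-folder per run (seed) and
    returns an empty list if the experiment folder does not exist.
    """
    root = Path(experiment_dir)
    if not root.is_dir():
        return []
    return sorted(p for p in root.iterdir() if p.is_dir())


if __name__ == "__main__":
    # Example path taken from the SatBird CISO checkpoints folder.
    for run_dir in list_run_dirs("model_checkpoints/2_SatBird/satbird_ciso"):
        print(run_dir)
```

If the list does not contain 3 entries for an experiment, the download is probably incomplete.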

## 🚀 Running code

### 🔹 Training
[Optional] You can log experiments using [Comet ML](https://www.comet.com/site/), a platform for visualizing metrics as your models are training. To enable logging, make sure to export your `COMET_API_KEY` and `COMET_WORKSPACE` environment variables:

```bash
export COMET_API_KEY=your_comet_api_key
export COMET_WORKSPACE=your_workspace
```
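
Before launching a run, it can help to confirm that both variables are visible to the Python process. The snippet below is a minimal sketch, not code from this repository; the helper name `missing_comet_vars` is hypothetical.

```python
import os

# The two variables named in the README section above.
REQUIRED_COMET_VARS = ("COMET_API_KEY", "COMET_WORKSPACE")


def missing_comet_vars(environ=None):
    """Return the names of required Comet variables that are unset or empty."""
    env = os.environ if environ is None else environ
    return [name for name in REQUIRED_COMET_VARS if not env.get(name)]


if __name__ == "__main__":
    missing = missing_comet_vars()
    if missing:
        print("Comet logging variables not set: " + ", ".join(missing))
    else:
        print("Comet logging variables are set.")
```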

To train a model, set `config.mode = "train"` and run:

```bash
python main.py --config=configs/<dataset>/<model_config>.yaml
```

Examples of configuration files for different datasets and models can be found in the `configs/` directory.

### 🔹 Evaluation

Use the `predict_family_of_species` parameter to control the subset of species evaluated. This parameter defaults to `-1` during training (i.e., not used).

#### Evaluating species groups *within* a dataset

🐦 **SatBird**:
- `predict_family_of_species = 0` → evaluate **non-songbirds**
- `predict_family_of_species = 1` → evaluate **songbirds**

🌿 **sPlotOpen**:
- `predict_family_of_species = 0` → evaluate **non-trees**
- `predict_family_of_species = 1` → evaluate **trees**


#### Evaluating species groups *across* datasets
🐦🦋 **SatBird & SatButterfly**:
- `predict_family_of_species = 0` → evaluate **birds**
- `predict_family_of_species = 1` → evaluate **butterflies**

🐦🌿 **SatBird & sPlotOpen**:
- `predict_family_of_species = 0` → evaluate **plants**
- `predict_family_of_species = 1` → evaluate **birds**


#### Conditioning on other species groups

For models that support partial labels (i.e., CISO and MLP++), evaluation can be conditioned on the observations of species from another group:

- `partial_labels.eval_known_ratio = 0` → evaluate with no partial labels (labels from the other species group are unknown)
- `partial_labels.eval_known_ratio = 1` → evaluate with partial labels (labels from the other species group are provided)
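
The mapping between parameter values and evaluated groups can be summarized in a small lookup. The following is an illustrative sketch only: `EVAL_GROUPS` and `describe_evaluation` are hypothetical names, not part of the repository, and the `eval_known` argument stands in for the conditioning setting described above.

```python
# Group evaluated for each value of predict_family_of_species,
# copied from the lists in this section.
EVAL_GROUPS = {
    "SatBird": {0: "non-songbirds", 1: "songbirds"},
    "sPlotOpen": {0: "non-trees", 1: "trees"},
    "SatBird & SatButterfly": {0: "birds", 1: "butterflies"},
    "SatBird & sPlotOpen": {0: "plants", 1: "birds"},
}


def describe_evaluation(dataset, predict_family_of_species, eval_known=0):
    """Describe which group is evaluated and whether the other group is given."""
    groups = EVAL_GROUPS[dataset]
    target = groups[predict_family_of_species]
    # The other group is the one selected by the opposite parameter value.
    other = groups[1 - predict_family_of_species]
    if eval_known == 1:
        return f"evaluate {target}, conditioned on {other}"
    return f"evaluate {target}, unconditioned"


if __name__ == "__main__":
    print(describe_evaluation("SatBird", 1, eval_known=1))
    print(describe_evaluation("SatBird & sPlotOpen", 0))
```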


## 📊 Reproducing Results and Figures

You can reproduce the key results and figures from the paper using the scripts and notebooks provided below:

### 📈 Results

To reproduce the results in Tables 1 & 2, as well as the additional metrics in Table 5, first download
the `model_checkpoints` folder from [here](https://huggingface.co/cisosdm/model_checkpoints).
Then, for a given dataset and model, run the command below with the corresponding config file and specify the
file name in which to save the results.
In the config file, set the following parameters:

| Parameter | Description |
|-----------------------------------|-----------------------------------------------------------------------------|
| `load_ckpt_path` | Path to the checkpoints folder or exact checkpoint path. |
| `predict_family_of_species` | Controls which family of species to evaluate as shown above in Evaluation. |
| `partial_labels/eval_known_ratio` | Set to 1 to condition on known labels for the other group of species. |

Set `config.mode = "test"` and run:
```bash
python main.py --config=configs/<dataset>/<model_config>.yaml --results_file_name=<results_file_name.csv>
```

##### Example:

To reproduce results for CISO on SatBird, use `configs/satbird/config_ciso.yaml`. Set `load_ckpt_path`
to `model_checkpoints/2_SatBird/satbird_ciso` to evaluate the 3 different runs.

* Evaluate **songbirds**, **unconditioned**: `predict_family_of_species = 1`.
* Evaluate **non-songbirds**, **unconditioned**: `predict_family_of_species = 0`.
* Evaluate **songbirds**, **conditioned** on non-songbirds: `predict_family_of_species = 1` and `eval_known_ratio = 1`.
* Evaluate **non-songbirds**, **conditioned** on songbirds: `predict_family_of_species = 0` and `eval_known_ratio = 1`.
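
The four settings above can be written down as a small evaluation plan. This is a sketch under stated assumptions, not a script from the repository: the results file names are hypothetical, an `eval_known_ratio` of 0 is assumed for the unconditioned runs, and the parameter values still have to be set in the config file before each run, since the README does not describe command-line overrides for them.

```python
import shlex

CONFIG = "configs/satbird/config_ciso.yaml"

# (label, predict_family_of_species, eval_known_ratio) for the four settings.
# The labels are hypothetical and only used to name the results files.
SETTINGS = [
    ("songbirds_unconditioned", 1, 0),
    ("nonsongbirds_unconditioned", 0, 0),
    ("songbirds_conditioned", 1, 1),
    ("nonsongbirds_conditioned", 0, 1),
]


def evaluation_plan(config=CONFIG, settings=SETTINGS):
    """Pair each setting with the test command documented in this README."""
    plan = []
    for label, family, known_ratio in settings:
        command = [
            "python",
            "main.py",
            f"--config={config}",
            f"--results_file_name={label}.csv",
        ]
        plan.append(
            {
                "label": label,
                "predict_family_of_species": family,
                "eval_known_ratio": known_ratio,
                "command": shlex.join(command),
            }
        )
    return plan


if __name__ == "__main__":
    for step in evaluation_plan():
        print(
            f"set predict_family_of_species={step['predict_family_of_species']}, "
            f"eval_known_ratio={step['eval_known_ratio']} in the config, then run:"
        )
        print("  " + step["command"])
```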

### 🖼️ Figures
* Figure 3: `figures/generate_figure_3.ipynb`
* Figure 4: `figures/generate_figure_4.ipynb`
* Figure 5: `figures/generate_figure_5.ipynb`

### Revision 1:
Code for additional baselines such as Random Forest and sjSDM is available under `src/models`.

## 📜 License
This work is licensed under a
[Creative Commons Attribution-NonCommercial 4.0 International (CC BY-NC 4.0) License](https://creativecommons.org/licenses/by-nc/4.0/).
24 changes: 0 additions & 24 deletions configs/defaults.yaml
@@ -75,39 +75,15 @@ losses:
criterion: "CE" #or MAE or MSE (loss to choose for optim)

metrics:
- name: ce
ignore: True
#weights on the cross entropy
lambd_pres: 1
lambd_abs: 1
scale : 1
- name: mae
ignore: False
scale: 10
- name: nonzero_mae
ignore: True
scale: 10
- name: mse
ignore: False
scale: 10
- name: nonzero_mse
ignore: True
scale: 10
- name: topk
ignore: False
scale: 1
- name: topk2
ignore: True
scale: 1
- name: r2
ignore: True #False
scale: 1
- name: kl
ignore: True
scale : 1
- name: accuracy
ignore: True
scale: 1
- name: top10
ignore: False
scale: 1
Original file line number Diff line number Diff line change
@@ -1,21 +1,21 @@
#where to save checkpoints
save_path: "multirun_experiments/satbird_ctran"
save_path: "model_checkpoints/2_SatBird/satbird_ciso"

# load existing checkpoint for inference. If passing experiment folder instead (for multiple seeds), it will evaluate all of them.
# always use the best checkpoint
load_ckpt_path: "multirun_experiments/satbird_ctran"
save_preds_path: "" #"/network/scratch/h/hager.radi/ecosystem-embedding/baseline_resnet18_RGBNIR_ENV/preds_path"
load_ckpt_path: "model_checkpoints/2_SatBird/satbird_ciso"
save_preds_path: ""

dataloader_to_use: "SDMEnvMaskedDataset"

comet:
project_name: "SDMPartialLabels"
tags: ["Ctran", "corrected_targets", "satbird"]
experiment_name: "satbird_ctran" # specify for training, or use to report test results, TODO: also use to resume training
experiment_name: "satbird_ciso" # specify for training, or use to report test results,
experiment_key: "" # use to report test results,

model:
name: "CTranModel"
name: "CISOModel"
input_dim: 27
hidden_dim: 256
backbone: "SimpleMLPBackbone"
@@ -35,8 +35,6 @@ losses:

partial_labels:
use: true
# mask known labels out of the loss (true or false)
masked_loss: False
# quantized mask (1 if all positives to 1, > 1 to indicate bins)
quantized_mask_bins: 4
# max ratio of unknown labels during training
@@ -64,7 +62,7 @@ data:
'phihox', 'sltppt', 'sndppt']

files:
base: "/network/projects/ecosystem-embeddings/SatBird_data_v2/USA_summer"
base: "data/SatBird_data_v2/USA_summer"
train: ["train_split.csv"]
val: ["valid_split.csv"]
test: ["test_split.csv"]
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
#where to save checkpoints
save_path: "multirun_experiments/satbird_linear_v1"
save_path: "model_checkpoints/2_SatBird/satbird_linear"
# load existing checkpoint for inference. If passing experiment folder instead (for multiple seeds), it will evaluate all of them.
# always use the best checkpoint
load_ckpt_path: "multirun_experiments/satbird_linear_v1"
load_ckpt_path: "model_checkpoints/2_SatBird/satbird_linear"
save_preds_path: ""

comet:
project_name: "SDMPartialLabels"
tags: [ "MLP", "corrected_targets", "EnvNormalization", "satbird" ]
experiment_name: "satbird_linear_v1" # specify for training, or use to report test results, TODO: also use to resume training
experiment_name: "satbird_linear" # specify for training, or use to report test results
experiment_key: "" # use to report test results,

dataloader_to_use: "SDMEnvDataset"
@@ -28,19 +28,8 @@ training:
max_epochs: 50
accelerator: "cpu"

partial_labels:
use: false
# mask known labels out of the loss (true or false)
masked_loss: False
# quantized mask (1 if all positives to 1, > 1 to indicate bins)
quantized_mask_bins: 4
# max ratio of unknown labels during training
train_known_ratio: 0.75
# what known ratios do we consider when testing
eval_known_ratio: 0.0 # [1.0, 0.9, 0.8, 0.5]

# During testing, eval family of non-songbirds (0), or family of songbirds (1)
predict_family_of_species: 1
predict_family_of_species: -1

data:
loaders:
@@ -55,13 +44,14 @@ data:
'phihox', 'sltppt', 'sndppt' ]

files:
base: "/Users/hagerradi/Projects/SDMPartialLabels/data"
train: [ "SatBird/train_split.csv" ]
val: [ "SatBird/valid_split.csv" ]
test: [ "SatBird/test_split.csv" ]
base: "/data/SatBird_data_v2/USA_summer"
train: [ "train_split.csv" ]
val: [ "valid_split.csv" ]
test: [ "test_split.csv" ]

targets_file: [ "satbird_usa_summer_targets.pkl" ]

targets_file: [ "satbird/satbird_usa_summer_targets.pkl" ]
satbird_species_indices_path: "satbird/stats"
satbird_species_indices_path: "stats"

multi_taxa: False
per_taxa_species_count: [ 670 ]