Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
41 commits
Select commit Hold shift + click to select a range
565c44b
add html parser
SaulLu Oct 29, 2021
f3f964e
add html parser tests
SaulLu Oct 29, 2021
7c7fa34
add dependencies to requirements
SaulLu Oct 29, 2021
a92d537
create HtmlPreprocessor
SaulLu Oct 29, 2021
a2b435c
add test for HtmlPreprocessor
SaulLu Oct 29, 2021
2f3e93a
fix test
SaulLu Oct 29, 2021
9c9d621
format
SaulLu Oct 29, 2021
6759584
change github workflow
SaulLu Oct 29, 2021
f9531e4
fix missing variable name
SaulLu Oct 29, 2021
4edf6c5
remove useless imports
SaulLu Oct 29, 2021
f94df44
change test function name
SaulLu Oct 29, 2021
da3a7c3
clean
SaulLu Oct 29, 2021
ca16a23
change comments
SaulLu Oct 29, 2021
8b94d31
add html specific code
SaulLu Nov 4, 2021
640d1ec
add local test to create dataset
SaulLu Nov 4, 2021
1edcabd
add print examples
SaulLu Nov 4, 2021
2dfacc4
add experiment 1 script
SaulLu Nov 5, 2021
2b73755
fix dataset name
SaulLu Nov 5, 2021
c2bba15
change script to use custom branch
SaulLu Nov 5, 2021
f6e73fb
change timing
SaulLu Nov 5, 2021
1163497
add experiments subexperiment 2
SaulLu Nov 9, 2021
02bbf2f
update subexperiment 1
SaulLu Nov 9, 2021
39f0de8
update custom code html only experiments
SaulLu Nov 9, 2021
66ce16a
make change to bsmetadata
SaulLu Nov 9, 2021
823b5e4
change test local exp
SaulLu Nov 9, 2021
8825e01
add subexp 3
SaulLu Nov 10, 2021
7e26c6f
add evaluation script for sub exp 1 and 2
SaulLu Nov 10, 2021
0adc44c
change to experiments
SaulLu Nov 12, 2021
099d33d
simplify the addition of fake tags
SaulLu Nov 23, 2021
78671a6
raise error for start parsing at
SaulLu Nov 23, 2021
663a5b2
rename `_get_text_and_metadata` by `_get_text_and_update_metadata`
SaulLu Nov 23, 2021
faac61f
add comment top-down and bottom-up
SaulLu Nov 23, 2021
f1b063a
add docstring to `HtmlPreprocessor`
SaulLu Nov 23, 2021
d399191
format
SaulLu Nov 23, 2021
2d2a77c
reformulate docstring
SaulLu Nov 23, 2021
6f435ed
Merge branch 'master' into add-html-preprocessing
SaulLu Nov 25, 2021
3e7a931
format
SaulLu Nov 25, 2021
a2ea092
change dependencies for pre-processing
SaulLu Nov 25, 2021
235b936
format
SaulLu Nov 25, 2021
ee6ba3d
change dependencies in github workflow
SaulLu Nov 25, 2021
4bbc424
Merge branch 'add-html-preprocessing' into add-html-exp
SaulLu Nov 25, 2021
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,11 @@ jobs:
- name: Install dependencies
run: |
python -m pip install --upgrade pip
python -m pip install .[entity_preprocessing]
python -m pip install .[preprocessing]
python -m pip install -r requirements-dev.txt
- name: Test
run: |
python -m pytest tests/test_get_dataloaders.py
python -m pytest tests/test_metadata_utils.py
python -m pytest tests/test_preprocessing_utils.py
python -m pytest tests/preprocessing_tools
211 changes: 211 additions & 0 deletions bsmetadata/evaluate.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,211 @@
import dataclasses
import gc
import json
import logging
import math
import os
import sys
from dataclasses import dataclass, field
from functools import partial
from typing import Optional

import hydra
import torch
import torch.nn.functional as F
import wandb
from accelerate import Accelerator
from datasets.features import Value
from hydra.core.config_store import ConfigStore
from omegaconf import OmegaConf
from tqdm.auto import tqdm as original_tqdm
from transformers import AdamW, AutoModelForCausalLM, AutoTokenizer, get_scheduler, set_seed
from transformers.trainer_utils import IntervalStrategy

from bsmetadata.input_pipeline import DataConfig, get_dataloaders


logger = logging.getLogger(__name__)


@dataclass
class CFG:
    """Configuration for an evaluation run over saved training checkpoints."""

    # Configuration of the data pipeline (dataset, tokenization, batching).
    data_config: DataConfig = DataConfig()
    out_dir: str = field(
        default="output_dir", metadata={"help": "The output directory in which the trained model is saved."}
    )
    training_jobid: Optional[str] = field(default=None, metadata={"help": "The jobid of the training run."})
    jobid: Optional[str] = field(default=None, metadata={"help": "The jobid of the evaluation."})
    checkpoints_to_evaluate: str = field(
        default="all",
        metadata={
            "help": "Indicate whether all checkpoints should be evaluated ('all') or only the last one ('last')"
        },
    )
    eval_name: str = field(
        default="ppl on val without metadata",
        # Fixed copy-paste error: this help previously duplicated `checkpoints_to_evaluate`'s text.
        metadata={"help": "A human-readable name for this evaluation, used to prefix the logged metric keys."},
    )
    seed: int = field(default=42, metadata={"help": "The seed used for RNG initialization."})
    model_name: str = field(default="gpt2", metadata={"help": "The name of the pretrained model to use."})
    project_name: str = field(default="metadata_lm", metadata={"help": "The project name."})
    do_eval: bool = field(default=True, metadata={"help": "Whether to run eval on the dev set."})


cs = ConfigStore.instance()
cs.store(name="config", node=CFG)


def show_help(context="", cls=CFG):
    """Recursively print every config field with its help text and default value.

    Args:
        context: Dotted prefix accumulated for nested fields (e.g. ``"data_config."``).
        cls: The dataclass config type to describe.
    """
    default_instance = cls()  # instantiate once so defaults (incl. factories) can be read
    for field_ in dataclasses.fields(cls):
        if dataclasses.is_dataclass(field_.type):
            # Nested config object: recurse with the field name added to the dotted prefix.
            show_help(context=f"{context}{field_.name}.", cls=field_.type)
        else:
            # `help_text` instead of `help` to avoid shadowing the builtin.
            help_text = field_.metadata.get("help", "")
            default = getattr(default_instance, field_.name)
            print(f"{context}{field_.name}: {help_text} (default={json.dumps(default)})")


class Logger:
    """Thin wandb wrapper that only acts on the local main process.

    Non-main processes construct a Logger whose methods are silent no-ops,
    so the calling code does not have to guard every logging call.
    """

    def __init__(self, is_local_main_process, *args, **kwargs):
        self.is_local_main_process = is_local_main_process
        if is_local_main_process:
            # Only the local main process opens a wandb run.
            self.run = wandb.init(*args, **kwargs)

    def log(self, dic):
        if not self.is_local_main_process:
            return
        wandb.log(dic)

    def close(self):
        if not self.is_local_main_process:
            return
        wandb.finish()


def loss_fn(batch, outputs, metadata_mask=None):
    """Compute the causal LM loss, averaged over non-masked token positions.

    Args:
        batch: Dict with at least ``"labels"`` and ``"attention_mask"`` tensors
            of shape (batch, seq) — assumed 0/1 masks; TODO confirm dtype at callers.
        outputs: Model output object exposing ``.logits`` of shape (batch, seq, vocab).
        metadata_mask: Optional (batch, seq) mask; positions set to 1 are
            metadata tokens that must be excluded from the loss.

    Returns:
        Scalar tensor: sum of per-token cross-entropy over kept positions,
        divided by the number of kept positions.
    """
    lm_logits = outputs.logits
    batch_size = lm_logits.size(0)
    labels = batch["labels"]
    attention_mask = batch["attention_mask"]

    # Standard causal shift: logits at position t predict the token at t+1.
    shift_logits = lm_logits[..., :-1, :].contiguous()
    shift_labels = labels[..., 1:].contiguous()
    if metadata_mask is not None:
        # Cast to bool before negating: on an integer 0/1 tensor `~` is a
        # *bitwise* not (1 -> -2, 0 -> -1, both truthy), which would silently
        # keep metadata positions in the loss. Boolean negation excludes them.
        loss_mask = torch.logical_and(attention_mask.bool(), ~metadata_mask.bool())
    else:
        loss_mask = attention_mask
    shift_mask = loss_mask[..., 1:].contiguous()
    # Flatten the tokens for cross entropy, then restore the per-example view.
    loss = F.cross_entropy(
        shift_logits.view(-1, shift_logits.size(-1)),
        shift_labels.view(-1),
        reduction="none",
    ).view(batch_size, -1)
    # Mean over kept positions only.
    loss = (loss * shift_mask).sum() / shift_mask.sum()
    # per-example ppl
    # ppl = torch.exp((loss * shift_mask).sum(-1) / shift_mask.sum(-1))
    return loss


@hydra.main(config_path=None, config_name="config")
def main(args: CFG) -> None:
    """Evaluate saved training checkpoints on the validation dataloaders.

    For every checkpoint file found under ``out_dir/training_jobid`` (all of
    them, or only the last one depending on ``checkpoints_to_evaluate``), the
    model state is loaded and perplexity is computed on each eval dataloader,
    then logged through wandb.
    """
    print(OmegaConf.to_yaml(args))

    # The dataset library uses the hash of the arguments to create the cache
    # name. Without this transformation the hash of args is not deterministic.
    args = OmegaConf.to_object(args)

    set_seed(args.seed)
    accelerator = Accelerator()
    is_local_main_process = accelerator.is_local_main_process
    tqdm = partial(original_tqdm, disable=not is_local_main_process, position=0)

    os.makedirs(args.out_dir, exist_ok=True)

    # Setup logging: only the local main process logs at INFO level.
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO if is_local_main_process else logging.WARN,
    )

    # get dataloaders
    logger.info("Load tokenizer")
    tokenizer = AutoTokenizer.from_pretrained(args.model_name)
    tokenizer.pad_token = tokenizer.eos_token
    logger.info("Load dataloaders")

    # Only the eval dataloaders are needed here; the train one is discarded.
    _, eval_dataloaders = get_dataloaders(tokenizer, args.data_config)
    logger.info("The dataloaders have been built")

    if not args.do_eval:
        return

    # get model
    logger.info("Load model")
    model = AutoModelForCausalLM.from_pretrained(args.model_name)

    # Prepare everything for (possibly distributed) execution.
    model = accelerator.prepare(model)
    eval_dataloaders = {k: accelerator.prepare(v) for k, v in eval_dataloaders.items()}

    # Note -> the training dataloader needs to be prepared before we grab its length below (cause its length will be
    # shorter in multiprocess)

    @torch.no_grad()
    def evaluate(eval_dataloader):
        """Return {'perplexity': float} of the current model on one dataloader."""
        model.eval()
        losses = []
        for step, batch in enumerate(tqdm(eval_dataloader, desc="eval")):  # , leave=False)
            labels = batch.pop("labels")
            metadata_mask = batch.pop("metadata_mask", None)
            outputs = model(**batch)
            batch["labels"] = labels
            loss = loss_fn(batch, outputs, metadata_mask)

            # Repeat so gather yields one value per example across processes.
            losses.append(accelerator.gather(loss.repeat(args.data_config.per_device_eval_batch_size)))

        losses = torch.cat(losses)
        perplexity = math.exp(torch.mean(losses))
        model.train()
        return {"perplexity": perplexity}

    logger_metrics = Logger(is_local_main_process, project=args.project_name, config=args)

    # Collect the training run's checkpoint files ("checkpoint-*.pt"), sorted
    # by name. Each entry is a full path under out_dir/training_jobid.
    checkpoint_dir = os.path.join(args.out_dir, args.training_jobid)
    checkpoint_names = sorted(
        os.path.join(checkpoint_dir, file_name)
        for file_name in os.listdir(checkpoint_dir)
        if file_name.split(".")[-1] == "pt" and file_name.split("-")[0] == "checkpoint"
    )
    if args.checkpoints_to_evaluate == "last":
        checkpoint_names = [checkpoint_names[-1]]
    elif args.checkpoints_to_evaluate != "all":
        raise ValueError("Wrong argument set for 'checkpoints_to_evaluate', valid possibilities are 'all' or 'last'.")

    logger.info(f"Will evaluate the following checkpoints: {checkpoint_names}")
    for checkpoint_path in checkpoint_names:
        # Fixed: entries of `checkpoint_names` are already full paths; the
        # previous `os.path.join(args.out_dir, args.jobid, file_name)` nested
        # the path under the evaluation jobid and could never exist on disk.
        base_name = os.path.basename(checkpoint_path)
        # File names look like "checkpoint-<step>step.pt" — TODO confirm with the training script.
        step = base_name.split(".")[0].split("-")[-1].split("step")[0]
        logger.info(f"Loading state dict for the checkpoint of step {step}")
        state_dict = torch.load(checkpoint_path)["state_dict"]
        logger.info("Loading state dict finished")

        model.load_state_dict(state_dict)

        logger.info(f"***** Evaluation step {step} *****")
        for key, eval_dataloader in eval_dataloaders.items():
            metrics = evaluate(eval_dataloader)
            logger_metrics.log({f"{args.eval_name} {key}": metrics, "step": step})

if __name__ == "__main__":
    # Hydra takes over argument parsing, so handle the help flags ourselves.
    if any(flag in sys.argv for flag in ("--help", "-h")):
        show_help()
        sys.exit()
    main()
26 changes: 26 additions & 0 deletions bsmetadata/experiments/with_metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,32 @@ def create_labels_column(examples):

logger.info(f" Num train examples = {len(train_dataset)}")
logger.info(f" Num validation examples = {len(val_dataset)}")
for idx in range(len(train_dataset)):
if 1 in train_dataset[idx]["metadata_mask"]:
logger.info(" Train sample with metadata")
logger.info(f" Train sample n°{idx} attention_mask:\n{train_dataset[idx]['attention_mask']}")
logger.info(f" Train sample n°{idx} metadata_mask:\n{train_dataset[idx]['metadata_mask']}")
logger.info(f" Train sample n°{idx} input_ids:\n{train_dataset[idx]['input_ids']}")
logger.info(
f" Train sample n°{idx} input_ids decoded:\n{tokenizer.decode(train_dataset[idx]['input_ids'])}"
)
logger.info(
f" Train sample n°{idx} tokens:\n{tokenizer.convert_ids_to_tokens(train_dataset[idx]['input_ids'])}"
)
break
for idx in range(len(train_dataset)):
if 1 not in train_dataset[idx]["metadata_mask"]:
logger.info(" Train sample without metadata")
logger.info(f" Train sample n°{idx} attention_mask:\n{train_dataset[idx]['attention_mask']}")
logger.info(f" Train sample n°{idx} metadata_mask:\n{train_dataset[idx]['metadata_mask']}")
logger.info(f" Train sample n°{idx} input_ids:\n{train_dataset[idx]['input_ids']}")
logger.info(
f" Train sample n°{idx} input_ids decoded:\n{tokenizer.decode(train_dataset[idx]['input_ids'])}"
)
logger.info(
f" Train sample n°{idx} tokens:\n{tokenizer.convert_ids_to_tokens(train_dataset[idx]['input_ids'])}"
)
break

# DataLoaders creation:
train_dataloader = DataLoader(
Expand Down
28 changes: 24 additions & 4 deletions bsmetadata/experiments/without_metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ def get_dataloaders(tokenizer, args):
datasets = load_dataset(
args.dataset_name,
args.dataset_config_name,
data_files=data_files,
cache_dir=args.cache_dir,
keep_in_memory=False,
)
Expand Down Expand Up @@ -103,7 +104,7 @@ def get_dataloaders(tokenizer, args):

# Preprocessing the datasets.
# First we tokenize all the texts.
column_names = datasets["train"].column_names
column_names = datasets["train"].column_names if "train" in datasets else datasets["validation"].column_names
text_column_name = "text" if "text" in column_names else column_names[0]

def tokenize_function(examples):
Expand Down Expand Up @@ -172,19 +173,38 @@ def group_texts(examples):
)
logger.info("Group texts finished")

train_dataset = datasets["train"]
train_dataset = datasets["train"] if "train" in datasets else None
val_dataset = datasets["validation"]

logger.info(f" Num train examples = {len(train_dataset)}")
if "train" in datasets:
logger.info(f" Num train examples = {len(train_dataset)}")
logger.info(f" Num validation examples = {len(val_dataset)}")
if "train" in datasets:
logger.info(" Train sample without metadata")
for idx in range(3):
logger.info(f" Train sample n°{idx} attention_mask:\n{train_dataset[idx]['attention_mask']}")
logger.info(f" Train sample n°{idx} input_ids:\n{train_dataset[idx]['input_ids']}")
logger.info(f" Train sample n°{idx} input_ids decoded:\n{tokenizer.decode(train_dataset[idx]['input_ids'])}")
logger.info(
f" Train sample n°{idx} tokens:\n{tokenizer.convert_ids_to_tokens(train_dataset[idx]['input_ids'])}"
)
else:
logger.info(" Validation sample without metadata")
for idx in range(3):
logger.info(f" Validation sample n°{idx} attention_mask:\n{val_dataset[idx]['attention_mask']}")
logger.info(f" Validation sample n°{idx} input_ids:\n{val_dataset[idx]['input_ids']}")
logger.info(f" Validation sample n°{idx} input_ids decoded:\n{tokenizer.decode(val_dataset[idx]['input_ids'])}")
logger.info(
f" Validation sample n°{idx} tokens:\n{tokenizer.convert_ids_to_tokens(val_dataset[idx]['input_ids'])}"
)

# DataLoaders creation:
train_dataloader = DataLoader(
train_dataset,
shuffle=True,
collate_fn=default_data_collator,
batch_size=args.per_device_train_batch_size,
)
) if "train" in datasets else None
val_dataloader1 = DataLoader(
val_dataset,
collate_fn=default_data_collator,
Expand Down
Empty file.
36 changes: 36 additions & 0 deletions bsmetadata/preprocessing_tools/html_parser/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
from typing import List, Optional

from bsmetadata.preprocessing_tools.html_parser.filters_and_cleaners import TextAndMetadataCleaner
from bsmetadata.preprocessing_tools.html_parser.objects import TagToRemoveWithContent


def get_clean_text_and_metadata(
    html_str,
    tags_to_remove_with_content: Optional[List[TagToRemoveWithContent]] = None,
    tags_to_remove_alone: Optional[List[str]] = None,
    attrs_to_keep: Optional[List[str]] = None,
    consecutive_tags_to_fold: Optional[List[str]] = None,
    convert_br_tag_to_breaking_line: Optional[bool] = False,
    txt_max_chr_len_alone: float = -float("inf"),
    txt_min_chr_len_alone: float = -float("inf"),
    tags_exceptions_to_txt_max_min_chr_len_alone: List[str] = None,
    txt_max_chr_len_with_content: float = -float("inf"),
    txt_min_chr_len_with_content: float = -float("inf"),
    tags_exceptions_to_txt_max_min_chr_len_with_content: List[str] = None,
):
    """Extract clean text and its metadata from an HTML document.

    Builds a ``TextAndMetadataCleaner`` over ``html_str`` — parsing always
    starts at the ``body`` tag — forwarding every filtering option unchanged,
    and returns the result of its ``apply()`` method.
    """
    cleaner = TextAndMetadataCleaner(
        html_str=html_str,
        start_parsing_at_tag="body",
        tags_to_remove_with_content=tags_to_remove_with_content,
        tags_to_remove_alone=tags_to_remove_alone,
        attrs_to_keep=attrs_to_keep,
        consecutive_tags_to_fold=consecutive_tags_to_fold,
        convert_br_tag_to_breaking_line=convert_br_tag_to_breaking_line,
        txt_max_chr_len_alone=txt_max_chr_len_alone,
        txt_min_chr_len_alone=txt_min_chr_len_alone,
        tags_exceptions_to_txt_max_min_chr_len_alone=tags_exceptions_to_txt_max_min_chr_len_alone,
        txt_max_chr_len_with_content=txt_max_chr_len_with_content,
        txt_min_chr_len_with_content=txt_min_chr_len_with_content,
        tags_exceptions_to_txt_max_min_chr_len_with_content=tags_exceptions_to_txt_max_min_chr_len_with_content,
    )
    return cleaner.apply()
Loading