Draft
Changes from all commits
Commits (90)
da38891
simplify code
SaulLu Aug 27, 2021
3c1d121
change how the tags are added - difference between self closing tag a…
SaulLu Aug 27, 2021
321411f
add test for the new way to add local metadata
SaulLu Aug 27, 2021
3088c91
add extension
SaulLu Aug 27, 2021
eadde2c
create special html processor
SaulLu Aug 27, 2021
1ef6ba9
add attributes to regular html processor
SaulLu Aug 27, 2021
26385c7
add test to custom html processor
SaulLu Aug 27, 2021
8b7655d
add baby training script to test
SaulLu Aug 27, 2021
7ae45dd
change requirements
SaulLu Aug 27, 2021
4d98709
add do train attribute
SaulLu Aug 30, 2021
d0ca7dc
add strat_training_example_script
SaulLu Aug 30, 2021
a3d60c7
format + fix content min char
SaulLu Aug 30, 2021
4363a9b
add file name in addition to dataset name
SaulLu Aug 30, 2021
d402eff
add example script
SaulLu Aug 30, 2021
070b285
add html parser dataclass
SaulLu Aug 30, 2021
830744d
add hash
SaulLu Aug 30, 2021
a8e3f4f
add html
SaulLu Aug 30, 2021
7a1b699
format + make the args hashable
SaulLu Aug 30, 2021
b3a20df
change requirements for the new method used
SaulLu Aug 30, 2021
70205a1
add script to run experiment
SaulLu Aug 31, 2021
cb3f9de
fix batch script
SaulLu Aug 31, 2021
ad808d3
fix code
SaulLu Aug 31, 2021
4618c39
change alloc
SaulLu Aug 31, 2021
c4969af
adjust bash script for preprocess
SaulLu Aug 31, 2021
e0c7246
remove space at top of file
SaulLu Aug 31, 2021
7335e34
remove ngpu
SaulLu Aug 31, 2021
6a77f4a
fix path
SaulLu Aug 31, 2021
18d0465
fix path
SaulLu Aug 31, 2021
8532f1b
refactor
SaulLu Aug 31, 2021
b0ddd7f
add config_path
SaulLu Aug 31, 2021
55e7cd4
add logging
SaulLu Aug 31, 2021
c5c0663
load model and dataset
SaulLu Aug 31, 2021
266a5ef
hydra config_path
SaulLu Aug 31, 2021
040cb12
add logging
SaulLu Aug 31, 2021
e33566f
set transformers login info
SaulLu Aug 31, 2021
5dd779b
add loggings
SaulLu Aug 31, 2021
e509b6d
add offline mode
SaulLu Aug 31, 2021
4d982a2
add offline mode
SaulLu Aug 31, 2021
28a9095
add logging info
SaulLu Aug 31, 2021
2a8d1d7
add log
SaulLu Aug 31, 2021
9e68903
remove dataset offline
SaulLu Aug 31, 2021
090188f
add even more logs
SaulLu Aug 31, 2021
ca8ba98
add log cache dir
SaulLu Aug 31, 2021
a633377
test dataset with squad
SaulLu Aug 31, 2021
eb3c079
data_files is None if empty
SaulLu Sep 1, 2021
a3d877c
data_files is None if empty
SaulLu Sep 1, 2021
07b44d2
replace squad by crime_and_punish
SaulLu Sep 1, 2021
35dbe23
create local repo
SaulLu Sep 1, 2021
59178ef
fix repo name
SaulLu Sep 1, 2021
2bdff39
add do_training
SaulLu Sep 1, 2021
6c62c27
add lines for repo init
SaulLu Sep 1, 2021
60e0e02
add experiment 1
SaulLu Sep 1, 2021
5ebc1e7
add multi_steps script
SaulLu Sep 1, 2021
4f73735
change time
SaulLu Sep 1, 2021
f9f7dfe
change time load dataset
SaulLu Sep 1, 2021
d563a63
change time
SaulLu Sep 1, 2021
5a94022
change multi_batch
SaulLu Sep 1, 2021
701f5da
remove useless file
SaulLu Sep 2, 2021
ea1219a
add html weights
SaulLu Sep 2, 2021
80a7ce9
fix logs
SaulLu Sep 3, 2021
25acad4
change to None empty data files
SaulLu Sep 6, 2021
b03cde3
see offline dataset
SaulLu Sep 6, 2021
ef90eb9
change requirements
SaulLu Sep 6, 2021
50f0776
Merge branch 'LS/html_parser' of github.com:bigscience-workshop/metad…
SaulLu Sep 6, 2021
c2ab822
remove unused files
SaulLu Sep 6, 2021
1d0d1b8
changes with metadata regarding dataset
SaulLu Sep 6, 2021
d7f7b9d
change train by adding new arguments
SaulLu Sep 6, 2021
09f55cd
Merge remote-tracking branch 'origin/LS/html_parser' into LS/html_parser
SaulLu Sep 6, 2021
af82ec0
change personal experiment
SaulLu Sep 6, 2021
0ad2d2c
change batch size
SaulLu Sep 6, 2021
9c30ce6
change batch size
SaulLu Sep 6, 2021
5753bb0
change experiment SLURM file
SaulLu Sep 6, 2021
6f1b455
Merge branch 'LS/improve_dataset' into LS/html_parser
SaulLu Sep 6, 2021
631b4c2
Merge branch 'master' into LS/html_parser
SaulLu Sep 6, 2021
49f5ef4
add JZ doc + new dataloader
SaulLu Sep 7, 2021
72e9976
create dedicated dir
SaulLu Sep 7, 2021
eadefbe
fix comments in command
SaulLu Sep 7, 2021
01f8467
add init in experiment folder
SaulLu Sep 7, 2021
dc5b9c2
change datasetname
SaulLu Sep 7, 2021
6af9a19
add readmes
SaulLu Sep 8, 2021
263eabb
format
SaulLu Sep 8, 2021
5b35b50
experiment template readme
SaulLu Sep 8, 2021
1d9e219
temporary change for tests
SaulLu Sep 8, 2021
1166d32
format
SaulLu Sep 8, 2021
7683dd5
isort
SaulLu Sep 8, 2021
099bbe2
flake8
SaulLu Sep 8, 2021
d025464
isort
SaulLu Sep 8, 2021
172c9b3
Merge branch 'LS/improve_dataset' into LS/html_parser
SaulLu Sep 8, 2021
4cdbe21
fix arg name
SaulLu Sep 8, 2021
32a7d34
Merge branch 'master' into LS/html_parser
SaulLu Sep 15, 2021
1 change: 1 addition & 0 deletions .github/workflows/test.yml
@@ -24,6 +24,7 @@ jobs:
run: |
python -m pip install --upgrade pip
python -m pip install .
python -m pip install -r requirements.txt
python -m pip install -r requirements-dev.txt
- name: Test
run: |
2 changes: 1 addition & 1 deletion README.md
@@ -5,7 +5,7 @@ This repository contains code for including metadata such as URLs, timestamps, w
## Usage

```sh
accelerate launch --fp16 train.py max_train_steps=100 num_eval=1 data_config.per_device_eval_batch_size=4
accelerate launch --fp16 train.py max_train_steps=100 eval_num_per_epoch=1 data_config.per_device_eval_batch_size=4
```

## Get Help
222 changes: 222 additions & 0 deletions bsmetadata/experiments/with_metadata_and_baseline_val.py
@@ -0,0 +1,222 @@
import copy
import functools
import logging

from datasets import config, load_dataset
from torch.utils.data import DataLoader
from transformers import default_data_collator

from bsmetadata.metadata_utils import add_metadata_and_chunk_examples


logger = logging.getLogger(__name__)


def get_dataloaders(tokenizer, args):
"""
Args:
tokenizer: a huggingface/transformers tokenizer
args: a DataConfig
Returns:
a training dataloader and a dictionary of validation dataloaders
each dataloader yields batches of the form {str: torch.Tensor (on CPU)}
a batch may contain a 'metadata_mask' key (padded like the other tensors)
all other fields are passed directly to the model
Example:
train_dataloader, val_dataloaders = get_dataloaders(...)
for batch in train_dataloader:
metadata_mask = batch.get('metadata_mask', None)
outputs = model(**batch)
metrics = loss_fn(batch, outputs, metadata_mask)
"""
# Mostly copy/paste from https://github.com/huggingface/transformers/blob/master/examples/pytorch/language-modeling/run_clm_no_trainer.py
# Get the datasets: you can either provide your own CSV/JSON/TXT training and evaluation files (see below)
# or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/
# (the dataset will be downloaded automatically from the datasets Hub).
#
# For CSV/JSON files, this script will use the column called 'text' or the first column if no column called
# 'text' is found. You can easily tweak this behavior (see below).
#
# In distributed training, the load_dataset function guarantees that only one local process can concurrently
# download the dataset.
data_files = {}
if args.train_file is not None:
data_files["train"] = args.train_file
if args.validation_file is not None:
data_files["validation"] = args.validation_file

if not data_files:
data_files = None

logger.info(f"Start to load dataset, the result will be cached at {config.HF_DATASETS_CACHE}")
if args.dataset_name is not None:
logger.info(
"Downloading with arguments: "
f"dataset_name={args.dataset_name}, "
f"dataset_config_name={args.dataset_config_name}, "
f"data_files={data_files}, "
f"cache_dir={args.cache_dir},"
)
# Downloading and loading a dataset from the hub.
datasets = load_dataset(
args.dataset_name,
args.dataset_config_name,
data_files=data_files,
cache_dir=args.cache_dir,
keep_in_memory=False,
)

if "validation" not in datasets.keys():
datasets["validation"] = load_dataset(
args.dataset_name,
args.dataset_config_name,
split=f"train[:{args.validation_split_percentage}%]",
cache_dir=args.cache_dir,
)
datasets["train"] = load_dataset(
args.dataset_name,
args.dataset_config_name,
split=f"train[{args.validation_split_percentage}%:]",
cache_dir=args.cache_dir,
)
else:
logger.info("Loading dataset from extension script")
extension = args.train_file.split(".")[-1] if not args.extension else args.extension
if extension == "txt":
raise ValueError(
"You have provided a text file as train data, but this file type cannot contain metadata "
"columns. Please provide the data in json/jsonl format instead."
)
if extension == "jsonl":
extension = "json"
datasets = load_dataset(extension, data_files=data_files, cache_dir=args.cache_dir)

if "validation" not in datasets.keys():
datasets["validation"] = load_dataset(
extension,
data_files=data_files,
split=f"train[:{args.validation_split_percentage}%]",
cache_dir=args.cache_dir,
)
datasets["train"] = load_dataset(
extension,
data_files=data_files,
split=f"train[{args.validation_split_percentage}%:]",
cache_dir=args.cache_dir,
)
logger.info("Dataset loaded")
# See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
# https://huggingface.co/docs/datasets/loading_datasets.html.

# Preprocessing the datasets.
column_names = datasets["train"].column_names

logger.info("Start to add metadata and chunk examples")

# Pin the attributes of the config that have no influence on the output of the next `map` call to fixed values,
# so that the datasets cache fingerprint stays stable and the cached results can be reused.
tmp_data_args = copy.deepcopy(args)
tmp_data_args.preprocessing_num_workers = 80
tmp_data_args.overwrite_cache = False
tmp_data_args.per_device_eval_batch_size = 2
tmp_data_args.per_device_train_batch_size = 2
tmp_data_args.map_batch_size = 1

# First we pre-process our text and metadata
datasets_with_metadata = datasets.map(
functools.partial(add_metadata_and_chunk_examples, tokenizer=tokenizer, cfg=tmp_data_args),
batched=True,
num_proc=args.preprocessing_num_workers,
load_from_cache_file=not args.overwrite_cache,
desc="Pre-process the text and metadata to create new samples",
remove_columns=column_names,
batch_size=args.map_batch_size,
)
logger.info("Add metadata and chunk examples finished")

def create_labels_column(examples):
examples["labels"] = examples["input_ids"].copy()
return examples

logger.info("Create labels column")
# Then we add the column containing the labels
datasets_with_metadata = datasets_with_metadata.map(
create_labels_column,
batched=True,
num_proc=args.preprocessing_num_workers,
load_from_cache_file=not args.overwrite_cache,
desc="Create labels column",
batch_size=args.map_batch_size,
)
logger.info("Creating labels column finished")

train_dataset = datasets_with_metadata["train"]
val_dataset1 = datasets_with_metadata["validation"]

# We create another validation dataset without metadata
logger.info("Start to chunk the baseline validation examples (metadata_probability = 0)")
tmp_data_args.metadata_probability = 0
val_dataset_without_metadata = datasets["validation"].map(
functools.partial(add_metadata_and_chunk_examples, tokenizer=tokenizer, cfg=tmp_data_args),
batched=True,
num_proc=args.preprocessing_num_workers,
load_from_cache_file=not args.overwrite_cache,
desc="Pre-process the text and metadata to create new samples",
remove_columns=column_names,
batch_size=args.map_batch_size,
)
logger.info("Add metadata and chunk examples finished")

def create_labels_column(examples):
examples["labels"] = examples["input_ids"].copy()
return examples

logger.info("Create labels column")
# Then we add the column containing the labels
val_dataset_without_metadata = val_dataset_without_metadata.map(
create_labels_column,
batched=True,
num_proc=args.preprocessing_num_workers,
load_from_cache_file=not args.overwrite_cache,
desc="Create labels column",
batch_size=args.map_batch_size,
)
logger.info("Creating labels column finished")
val_dataset2 = val_dataset_without_metadata

logger.info(f" Num train examples = {len(train_dataset)}")
logger.info(f" Num validation examples dataloader 1 = {len(val_dataset1)}")
logger.info(f" Num validation examples dataloader 2 = {len(val_dataset2)}")

logger.info(f" Train examples = {train_dataset[0]}")
logger.info(f" Validation examples dataloader 1 = {val_dataset1[0]}")
logger.info(f" Validation examples dataloader 2 = {val_dataset2[0]}")

logger.info(f' Train examples = {tokenizer.convert_ids_to_tokens(train_dataset[0]["input_ids"])}')
logger.info(
f' Validation examples dataloader 1 = {tokenizer.convert_ids_to_tokens(val_dataset1[0]["input_ids"])}'
)
logger.info(
f' Validation examples dataloader 2 = {tokenizer.convert_ids_to_tokens(val_dataset2[0]["input_ids"])}'
)

# DataLoaders creation:
train_dataloader = DataLoader(
train_dataset,
shuffle=True,
collate_fn=default_data_collator,
batch_size=args.per_device_train_batch_size,
)
val_dataloader1 = DataLoader(
val_dataset1,
collate_fn=default_data_collator,
batch_size=args.per_device_eval_batch_size,
)
val_dataloader2 = DataLoader(
val_dataset2,
collate_fn=default_data_collator,
batch_size=args.per_device_eval_batch_size,
)
return train_dataloader, {"val1": val_dataloader1, "val2": val_dataloader2}
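
For orientation, a hedged sketch of how a training script could consume what this experiment returns (`model` and `loss_fn` are placeholders, not the repository's actual train.py); `val2` is the metadata-free baseline built above with `metadata_probability = 0`:

```python
# Hypothetical consumer of get_dataloaders, following the contract documented
# in the docstring above; `model` and `loss_fn` are placeholders.
train_dataloader, val_dataloaders = get_dataloaders(tokenizer, args)

for batch in train_dataloader:
    # Remove the mask so that only model inputs remain (assuming the model
    # itself does not accept a `metadata_mask` argument).
    metadata_mask = batch.pop("metadata_mask", None)
    outputs = model(**batch)
    loss = loss_fn(batch, outputs, metadata_mask)

# "val1" evaluates with metadata, "val2" is the no-metadata baseline.
for name, val_dataloader in val_dataloaders.items():
    for batch in val_dataloader:
        metadata_mask = batch.pop("metadata_mask", None)
        outputs = model(**batch)
        metrics = loss_fn(batch, outputs, metadata_mask)
```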
5 changes: 5 additions & 0 deletions bsmetadata/input_pipeline.py
@@ -88,6 +88,11 @@ def get_dataloaders(tokenizer, cfg: DataConfig):
if cfg.experiment == "with_metadata":
from bsmetadata.experiments.with_metadata import get_dataloaders as fn

return fn(tokenizer, cfg)

if cfg.experiment == "with_metadata_and_baseline_val":
from bsmetadata.experiments.with_metadata_and_baseline_val import get_dataloaders as fn

return fn(tokenizer, cfg)
else:
raise ValueError("You have not entered a valid experiment name")
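
Assuming the `DataConfig` is exposed under the `data_config` key, as in the README command above, selecting the new experiment would presumably look like this (the exact Hydra key path is an assumption, not taken from the repository):

```sh
accelerate launch --fp16 train.py data_config.experiment=with_metadata_and_baseline_val max_train_steps=100
```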
8 changes: 7 additions & 1 deletion bsmetadata/metadata_processors.py
@@ -117,7 +117,13 @@ class HtmlProcessor(MetadataProcessor):
def process_local(self, metadata_attrs: Dict[str, Any]) -> Optional[Tuple[str, str]]:
# We represent an HTML tag `T` by enclosing the corresponding text span with "<T>" and "</T>".
# Example: An <b>apple</b> is an edible fruit.
return f"<{metadata_attrs['value']}>", f"</{metadata_attrs['value']}>"
attributes = " ".join(
f'{attr}:"{value}"'
for attr, value in zip(metadata_attrs["value"]["attrs"]["attr"], metadata_attrs["value"]["attrs"]["value"])
)
if attributes:
attributes = " " + attributes
return f"<{metadata_attrs['value']['tag']}{attributes}>", f"</{metadata_attrs['value']['tag']}>"


class UrlProcessor(MetadataProcessor):
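For illustration, a minimal standalone sketch of the formatting logic added to `process_local` above (it does not instantiate `HtmlProcessor`, to avoid assuming its constructor signature; the nested `value` dict mirrors the schema used by the new code):

```python
# Rebuild the start/end tags the way the updated process_local does.
value = {"tag": "b", "attrs": {"attr": ["class"], "value": ["important"]}}

attributes = " ".join(
    f'{attr}:"{val}"'
    for attr, val in zip(value["attrs"]["attr"], value["attrs"]["value"])
)
if attributes:
    attributes = " " + attributes

start_tag, end_tag = f"<{value['tag']}{attributes}>", f"</{value['tag']}>"
print(start_tag, end_tag)  # -> <b class:"important"> </b>
```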
62 changes: 51 additions & 11 deletions bsmetadata/metadata_utils.py
@@ -15,6 +15,7 @@
"""
import random
from collections import defaultdict
from dataclasses import dataclass, field
from typing import Any, Dict, List, Tuple

from transformers import PreTrainedTokenizerFast
@@ -121,6 +122,14 @@ def create_global_metadata_prefix(example: Dict[str, Any], cfg: MetadataConfig)
return cfg.metadata_sep.join(sorted_metadata) + cfg.global_metadata_sep if sorted_metadata else ""


@dataclass
class MetadataIdxStorage:
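"""Maps character indices to the metadata strings to insert at that position, keeping tags that enclose content separate from empty tags (char_start_idx == char_end_idx) so both kinds can be ordered correctly when they share an index."""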
start_idx_tag_with_content: dict = field(default_factory=(lambda: defaultdict(list)))
end_idx_tag_with_content: dict = field(default_factory=(lambda: defaultdict(list)))
start_idx_tag_without_content: dict = field(default_factory=(lambda: defaultdict(list)))
end_idx_tag_without_content: dict = field(default_factory=(lambda: defaultdict(list)))


def add_local_metadata_to_text(example: Dict[str, Any], cfg: MetadataConfig) -> Tuple[str, List[bool]]:
"""Adds local metadata (such as HTML tags and entity names) to the given input text.

@@ -133,7 +142,7 @@ def add_local_metadata_to_text(example: Dict[str, Any], cfg: MetadataConfig) ->
- the first element is the text with metadata;
- the second element is a boolean mask where `mask[i]` is set iff `text[i]` is some kind of metadata.
"""
metadata_start_texts, metadata_end_texts = defaultdict(list), defaultdict(list)
metadata_idx_storage = MetadataIdxStorage()

# Filter and sort all metadata so that they are processed in the requested order.
filtered_metadata = [md for md in example["metadata"] if md["type"] == "local" and md["key"] in cfg.metadata_list]
@@ -151,27 +160,58 @@ def add_local_metadata_to_text(example: Dict[str, Any], cfg: MetadataConfig) ->
char_start_idx = metadata.get("char_start_idx", -1)
char_end_idx = metadata.get("char_end_idx", -1)

metadata_start_texts[char_start_idx].insert(0, start_text)
metadata_end_texts[char_end_idx].append(end_text)
if char_start_idx == char_end_idx:
metadata_idx_storage.start_idx_tag_without_content[char_start_idx].insert(0, start_text)
metadata_idx_storage.end_idx_tag_without_content[char_end_idx].append(end_text)
else:
metadata_idx_storage.start_idx_tag_with_content[char_start_idx].insert(0, start_text)
metadata_idx_storage.end_idx_tag_with_content[char_end_idx].append(end_text)

# Build the final text with local metadata and the corresponding mask.
text_with_local_metadata = []
metadata_mask = []

def _add_metadata_to_text(metadata_text_list, text_with_local_metadata, metadata_mask):
for metadata_text in metadata_text_list:
text_with_local_metadata.append(metadata_text)
metadata_mask += [True] * len(metadata_text)

for idx, char in enumerate(example["text"]):
if idx in metadata_idx_storage.end_idx_tag_with_content:
metadata_text_list = metadata_idx_storage.end_idx_tag_with_content[idx]
_add_metadata_to_text(metadata_text_list, text_with_local_metadata, metadata_mask)

if idx in metadata_idx_storage.start_idx_tag_without_content:
metadata_text_list = metadata_idx_storage.start_idx_tag_without_content[idx]
_add_metadata_to_text(metadata_text_list, text_with_local_metadata, metadata_mask)

if idx in metadata_start_texts:
for start_text in metadata_start_texts[idx]:
text_with_local_metadata.append(start_text)
metadata_mask += [True] * len(start_text)
if idx in metadata_idx_storage.end_idx_tag_without_content:
metadata_text_list = metadata_idx_storage.end_idx_tag_without_content[idx]
_add_metadata_to_text(metadata_text_list, text_with_local_metadata, metadata_mask)

if idx in metadata_idx_storage.start_idx_tag_with_content:
metadata_text_list = metadata_idx_storage.start_idx_tag_with_content[idx]
_add_metadata_to_text(metadata_text_list, text_with_local_metadata, metadata_mask)

text_with_local_metadata.append(char)
metadata_mask += [False]

if idx + 1 in metadata_end_texts:
for end_text in metadata_end_texts[idx + 1]:
text_with_local_metadata.append(end_text)
metadata_mask += [True] * len(end_text)
idx += 1
if idx in metadata_idx_storage.end_idx_tag_with_content:
metadata_text_list = metadata_idx_storage.end_idx_tag_with_content[idx]
_add_metadata_to_text(metadata_text_list, text_with_local_metadata, metadata_mask)

if idx in metadata_idx_storage.start_idx_tag_without_content:
metadata_text_list = metadata_idx_storage.start_idx_tag_without_content[idx]
_add_metadata_to_text(metadata_text_list, text_with_local_metadata, metadata_mask)

if idx in metadata_idx_storage.end_idx_tag_without_content:
metadata_text_list = metadata_idx_storage.end_idx_tag_without_content[idx]
_add_metadata_to_text(metadata_text_list, text_with_local_metadata, metadata_mask)

if idx in metadata_idx_storage.start_idx_tag_with_content:
metadata_text_list = metadata_idx_storage.start_idx_tag_with_content[idx]
_add_metadata_to_text(metadata_text_list, text_with_local_metadata, metadata_mask)

return "".join(text_with_local_metadata), metadata_mask
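
To make the new ordering concrete, here is a small self-contained sketch (it reproduces the emission order of the loop above with plain dicts rather than calling `add_local_metadata_to_text`, whose full `MetadataConfig` is not shown here): a tag without content at index 0 is opened and closed before the opening tag of an element that encloses content starting at the same index.

```python
from collections import defaultdict

text = "Hello"
# Tags that enclose content vs. tags with char_start_idx == char_end_idx.
start_with, end_with = defaultdict(list), defaultdict(list)
start_without, end_without = defaultdict(list), defaultdict(list)

start_with[0].append("<div>"); end_with[len(text)].append("</div>")  # <div>Hello</div>
start_without[0].append("<br>"); end_without[0].append("</br>")      # empty tag at index 0

out = []
for idx, char in enumerate(text):
    # Same per-index order as the loop above: close tags with content, then
    # open and close empty tags, then open tags with content, then the char.
    out += end_with[idx] + start_without[idx] + end_without[idx] + start_with[idx]
    out.append(char)
idx = len(text)
out += end_with[idx] + start_without[idx] + end_without[idx] + start_with[idx]
print("".join(out))  # -> <br></br><div>Hello</div>
```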
