From 52b69764637f3fec4be654374af0243402fcaed5 Mon Sep 17 00:00:00 2001
From: Masoud Jalili Sabet
Date: Fri, 16 Dec 2022 04:43:13 +0100
Subject: [PATCH 1/2] Add CM3 loss to bsmetadata/metadata_utils.py

---
 bsmetadata/hydra_configs/v2.yaml  |  1 +
 bsmetadata/metadata_processors.py |  6 ++++++
 bsmetadata/metadata_utils.py      | 12 ++++++++++++
 bsmetadata/train.py               | 10 ++++++++--
 4 files changed, 27 insertions(+), 2 deletions(-)

diff --git a/bsmetadata/hydra_configs/v2.yaml b/bsmetadata/hydra_configs/v2.yaml
index ee6a6a5d..8ace0fe9 100644
--- a/bsmetadata/hydra_configs/v2.yaml
+++ b/bsmetadata/hydra_configs/v2.yaml
@@ -48,6 +48,7 @@ data_config:
     metadata_prefix_sep: ' |||'
     metadata_prefix_start_seq: ''
     max_seq_len: 1024
+    apply_cm3_loss_to_sequences: false
    html_parser_config:
      all_tags_rules:
        attributes_to_keep:
diff --git a/bsmetadata/metadata_processors.py b/bsmetadata/metadata_processors.py
index 755041f6..ce393592 100644
--- a/bsmetadata/metadata_processors.py
+++ b/bsmetadata/metadata_processors.py
@@ -174,6 +174,12 @@ class MetadataConfig:
     max_seq_len: int = field(
         default=512, metadata={"help": "The maximum number of tokens to use for each training chunk."}
     )
+    apply_cm3_loss_to_sequences: bool = field(
+        default=False,
+        metadata={
+            "help": "If True, the CM3 loss will be applied to training input sequences. "
+        },
+    )
     html_parser_config: Optional[HTMLParserConfig] = HTMLParserConfig(
         AllTagsRules(
             attributes_to_keep=None,
diff --git a/bsmetadata/metadata_utils.py b/bsmetadata/metadata_utils.py
index baa632c2..0552d6a0 100644
--- a/bsmetadata/metadata_utils.py
+++ b/bsmetadata/metadata_utils.py
@@ -120,10 +120,22 @@ def is_metadata(idx: int) -> bool:
     # Create chunks of `max_seq_len` tokens.
     prefix_len = len(metadata_prefix_encoded)
     max_text_len = cfg.max_seq_len - prefix_len
+    if cfg.apply_cm3_loss_to_sequences:
+        max_text_len -= 2
 
     for text_chunk_encoded, chunk_metadata_mask in chunks(
         max_text_len, text_with_local_metadata_encoded.input_ids, token_level_metadata_mask
     ):
+        if cfg.apply_cm3_loss_to_sequences:
+            span_start, span_end = random.randint(0, len(text_chunk_encoded)), random.randint(0, len(text_chunk_encoded))
+            if span_end < span_start:
+                span_start, span_end = span_end, span_start
+            if span_end - span_start > 0:
+                text_chunk_encoded = text_chunk_encoded[:span_start] + [tokenizer.mask_token_id] + \
+                    text_chunk_encoded[span_end:] + [tokenizer.mask_token_id] + text_chunk_encoded[span_start: span_end]
+                chunk_metadata_mask = chunk_metadata_mask[:span_start] + [1] + \
+                    chunk_metadata_mask[span_end:] + [1] + chunk_metadata_mask[span_start: span_end]
+
         total_len = prefix_len + len(text_chunk_encoded)
         padding_len = max_text_len - len(text_chunk_encoded)
 
diff --git a/bsmetadata/train.py b/bsmetadata/train.py
index b0b46428..825e6eda 100644
--- a/bsmetadata/train.py
+++ b/bsmetadata/train.py
@@ -278,10 +278,16 @@ def main(args: CFG) -> None:
         new_tokens = [
             AddedToken(token, rstrip=False, lstrip=False, single_word=False, normalized=False) for token in new_tokens
         ]
-        tokenizer = AutoTokenizer.from_pretrained(args.model_name, additional_special_tokens=new_tokens)
     else:
-        tokenizer = AutoTokenizer.from_pretrained(args.model_name)
+        new_tokens = []
+
+    new_tokens += [AddedToken("", rstrip=False, lstrip=False, single_word=False, normalized=False)]
+    tokenizer = AutoTokenizer.from_pretrained(args.model_name, additional_special_tokens=new_tokens)
+
+    tokenizer.mask_token = ""
+    tokenizer.mask_token_id = tokenizer.convert_tokens_to_ids(tokenizer.mask_token)
     tokenizer.pad_token = tokenizer.eos_token
+
     if args.data_config.experiment == "with_metadata_datasetv2_tf":
         from bsmetadata.experiments.with_metadata_datasetv2_tf import get_dataloader, get_dummy_dataloader
 

From 913364e3645239be08b427b21250b9b10776125b Mon Sep 17 00:00:00 2001
From: Masoud Jalili Sabet
Date: Fri, 16 Dec 2022 05:15:50 +0100
Subject: [PATCH 2/2] add styling changes to cm3 loss

---
 bsmetadata/metadata_processors.py |  4 +---
 bsmetadata/metadata_utils.py      | 23 ++++++++++++++++-------
 2 files changed, 17 insertions(+), 10 deletions(-)

diff --git a/bsmetadata/metadata_processors.py b/bsmetadata/metadata_processors.py
index ce393592..b153051b 100644
--- a/bsmetadata/metadata_processors.py
+++ b/bsmetadata/metadata_processors.py
@@ -176,9 +176,7 @@ class MetadataConfig:
     )
     apply_cm3_loss_to_sequences: bool = field(
         default=False,
-        metadata={
-            "help": "If True, the CM3 loss will be applied to training input sequences. "
-        },
+        metadata={"help": "If True, the CM3 loss will be applied to training input sequences. "},
     )
     html_parser_config: Optional[HTMLParserConfig] = HTMLParserConfig(
         AllTagsRules(
diff --git a/bsmetadata/metadata_utils.py b/bsmetadata/metadata_utils.py
index 0552d6a0..87371b5b 100644
--- a/bsmetadata/metadata_utils.py
+++ b/bsmetadata/metadata_utils.py
@@ -127,14 +127,23 @@ def is_metadata(idx: int) -> bool:
         max_text_len, text_with_local_metadata_encoded.input_ids, token_level_metadata_mask
     ):
         if cfg.apply_cm3_loss_to_sequences:
-            span_start, span_end = random.randint(0, len(text_chunk_encoded)), random.randint(0, len(text_chunk_encoded))
-            if span_end < span_start:
-                span_start, span_end = span_end, span_start
+            span_ids = sorted([random.randint(0, len(text_chunk_encoded)) for x in range(2)])
+            span_start, span_end = span_ids[0], span_ids[1]
             if span_end - span_start > 0:
-                text_chunk_encoded = text_chunk_encoded[:span_start] + [tokenizer.mask_token_id] + \
-                    text_chunk_encoded[span_end:] + [tokenizer.mask_token_id] + text_chunk_encoded[span_start: span_end]
-                chunk_metadata_mask = chunk_metadata_mask[:span_start] + [1] + \
-                    chunk_metadata_mask[span_end:] + [1] + chunk_metadata_mask[span_start: span_end]
+                text_chunk_encoded = (
+                    text_chunk_encoded[:span_start]
+                    + [tokenizer.mask_token_id]
+                    + text_chunk_encoded[span_end:]
+                    + [tokenizer.mask_token_id]
+                    + text_chunk_encoded[span_start:span_end]
+                )
+                chunk_metadata_mask = (
+                    chunk_metadata_mask[:span_start]
+                    + [1]
+                    + chunk_metadata_mask[span_end:]
+                    + [1]
+                    + chunk_metadata_mask[span_start:span_end]
+                )
 
         total_len = prefix_len + len(text_chunk_encoded)
         padding_len = max_text_len - len(text_chunk_encoded)
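
Below is a standalone sketch (not part of the patch) of the span-masking transform that `apply_cm3_loss_to_sequences` turns on in `metadata_utils.py`: a random span of each training chunk is cut out, replaced by a mask sentinel, and re-appended at the end of the chunk behind a second sentinel, which is also why the patch reserves two extra positions with `max_text_len -= 2`. The names `MASK_TOKEN_ID` and `cm3_rotate_span` are illustrative stand-ins for `tokenizer.mask_token_id` and the inline loop body added by the patch.

# Illustrative sketch only -- mirrors the logic the patch adds to metadata_utils.py.
import random
from typing import List, Tuple

MASK_TOKEN_ID = 50257  # hypothetical id standing in for tokenizer.mask_token_id


def cm3_rotate_span(token_ids: List[int], metadata_mask: List[int]) -> Tuple[List[int], List[int]]:
    """Cut one random span out of the chunk and re-append it at the end.

    As in the patch, the two inserted sentinels are flagged with 1 in the
    metadata mask, while the moved span keeps its original mask values.
    """
    span_start, span_end = sorted(random.randint(0, len(token_ids)) for _ in range(2))
    if span_end == span_start:
        return token_ids, metadata_mask  # empty span: chunk is left unchanged

    token_ids = (
        token_ids[:span_start]
        + [MASK_TOKEN_ID]
        + token_ids[span_end:]
        + [MASK_TOKEN_ID]
        + token_ids[span_start:span_end]
    )
    metadata_mask = (
        metadata_mask[:span_start]
        + [1]
        + metadata_mask[span_end:]
        + [1]
        + metadata_mask[span_start:span_end]
    )
    return token_ids, metadata_mask


# Example: if the sampled span is (1, 3), tokens 12 and 13 are rotated to the end:
# cm3_rotate_span([11, 12, 13, 14, 15], [0, 0, 0, 0, 0])
# -> ([11, MASK, 14, 15, MASK, 12, 13], [0, 1, 0, 0, 1, 0, 0])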