diff --git a/bsmetadata/hydra_configs/v2.yaml b/bsmetadata/hydra_configs/v2.yaml
index ee6a6a5d..8ace0fe9 100644
--- a/bsmetadata/hydra_configs/v2.yaml
+++ b/bsmetadata/hydra_configs/v2.yaml
@@ -48,6 +48,7 @@ data_config:
   metadata_prefix_sep: ' |||'
   metadata_prefix_start_seq: ''
   max_seq_len: 1024
+  apply_cm3_loss_to_sequences: false
   html_parser_config:
     all_tags_rules:
       attributes_to_keep:
diff --git a/bsmetadata/metadata_processors.py b/bsmetadata/metadata_processors.py
index 755041f6..b153051b 100644
--- a/bsmetadata/metadata_processors.py
+++ b/bsmetadata/metadata_processors.py
@@ -174,6 +174,10 @@ class MetadataConfig:
     max_seq_len: int = field(
         default=512, metadata={"help": "The maximum number of tokens to use for each training chunk."}
     )
+    apply_cm3_loss_to_sequences: bool = field(
+        default=False,
+        metadata={"help": "If True, the CM3 loss will be applied to training input sequences."},
+    )
     html_parser_config: Optional[HTMLParserConfig] = HTMLParserConfig(
         AllTagsRules(
             attributes_to_keep=None,
diff --git a/bsmetadata/metadata_utils.py b/bsmetadata/metadata_utils.py
index baa632c2..87371b5b 100644
--- a/bsmetadata/metadata_utils.py
+++ b/bsmetadata/metadata_utils.py
@@ -120,10 +120,31 @@ def is_metadata(idx: int) -> bool:
     # Create chunks of `max_seq_len` tokens.
     prefix_len = len(metadata_prefix_encoded)
     max_text_len = cfg.max_seq_len - prefix_len
+    if cfg.apply_cm3_loss_to_sequences:
+        max_text_len -= 2

     for text_chunk_encoded, chunk_metadata_mask in chunks(
         max_text_len, text_with_local_metadata_encoded.input_ids, token_level_metadata_mask
     ):
+        if cfg.apply_cm3_loss_to_sequences:
+            span_ids = sorted([random.randint(0, len(text_chunk_encoded)) for x in range(2)])
+            span_start, span_end = span_ids[0], span_ids[1]
+            if span_end - span_start > 0:
+                text_chunk_encoded = (
+                    text_chunk_encoded[:span_start]
+                    + [tokenizer.mask_token_id]
+                    + text_chunk_encoded[span_end:]
+                    + [tokenizer.mask_token_id]
+                    + text_chunk_encoded[span_start:span_end]
+                )
+                chunk_metadata_mask = (
+                    chunk_metadata_mask[:span_start]
+                    + [1]
+                    + chunk_metadata_mask[span_end:]
+                    + [1]
+                    + chunk_metadata_mask[span_start:span_end]
+                )
+
         total_len = prefix_len + len(text_chunk_encoded)
         padding_len = max_text_len - len(text_chunk_encoded)

diff --git a/bsmetadata/train.py b/bsmetadata/train.py
index b0b46428..825e6eda 100644
--- a/bsmetadata/train.py
+++ b/bsmetadata/train.py
@@ -278,10 +278,16 @@ def main(args: CFG) -> None:
         new_tokens = [
            AddedToken(token, rstrip=False, lstrip=False, single_word=False, normalized=False) for token in new_tokens
         ]
-        tokenizer = AutoTokenizer.from_pretrained(args.model_name, additional_special_tokens=new_tokens)
     else:
-        tokenizer = AutoTokenizer.from_pretrained(args.model_name)
+        new_tokens = []
+
+    new_tokens += [AddedToken("", rstrip=False, lstrip=False, single_word=False, normalized=False)]
+    tokenizer = AutoTokenizer.from_pretrained(args.model_name, additional_special_tokens=new_tokens)
+
+    tokenizer.mask_token = ""
+    tokenizer.mask_token_id = tokenizer.convert_tokens_to_ids(tokenizer.mask_token)
     tokenizer.pad_token = tokenizer.eos_token
+
     if args.data_config.experiment == "with_metadata_datasetv2_tf":
         from bsmetadata.experiments.with_metadata_datasetv2_tf import get_dataloader, get_dummy_dataloader
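
For context, the `metadata_utils.py` hunk applies CM3-style causal masking to each training chunk: a random span is cut out, replaced by a mask sentinel, and re-appended at the end of the sequence after a second sentinel, so a causal language model still sees, and is trained to reproduce, the masked tokens. The sketch below is a minimal, self-contained illustration of that transform, not the repository's code; the function name `cm3_mask_span`, the plain list-of-ints interface, and the example mask id are assumptions made for illustration.

```python
import random


def cm3_mask_span(token_ids, mask_token_id, rng=random):
    """Illustrative CM3-style masking: pick a random span, replace it with a
    sentinel, and move the span's tokens to the end of the sequence after a
    second sentinel. (Hypothetical helper, not the bsmetadata implementation.)"""
    # Two random cut points define the half-open span [span_start, span_end).
    span_start, span_end = sorted(rng.randint(0, len(token_ids)) for _ in range(2))
    if span_end == span_start:
        # Empty span: leave the sequence unchanged.
        return list(token_ids)
    return (
        token_ids[:span_start]              # prefix
        + [mask_token_id]                   # sentinel marking where the span was
        + token_ids[span_end:]              # suffix
        + [mask_token_id]                   # sentinel introducing the moved span
        + token_ids[span_start:span_end]    # the masked span, now at the end
    )


# Example with a hypothetical mask id of 50257:
# [1, 2, 3, 4, 5] might become [1, 50257, 5, 50257, 2, 3, 4].
print(cm3_mask_span([1, 2, 3, 4, 5], mask_token_id=50257))
```

Because the transform inserts two sentinel tokens, the output is two tokens longer than the input, which is why the diff shrinks `max_text_len` by 2 before chunking. The parallel edit to `chunk_metadata_mask` sets 1 at both sentinel positions, presumably so they are handled like metadata tokens downstream.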