From 11807a87a183f1a827f84617665267750bf1e0ef Mon Sep 17 00:00:00 2001
From: Jonathan Chang
Date: Sun, 8 Jan 2023 21:02:44 +0800
Subject: [PATCH 01/13] Fix mask bug

---
 bsmetadata/evaluation.py | 2 +-
 bsmetadata/train.py      | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/bsmetadata/evaluation.py b/bsmetadata/evaluation.py
index 38bdf6ab..7285dd94 100644
--- a/bsmetadata/evaluation.py
+++ b/bsmetadata/evaluation.py
@@ -37,7 +37,7 @@ def ppl_fn(
     shift_labels = labels[..., 1:].contiguous()
 
     if metadata_mask is not None:
-        loss_mask = torch.logical_and(attention_mask, ~metadata_mask)
+        loss_mask = torch.logical_and(attention_mask, ~(metadata_mask.bool()))
     else:
         loss_mask = attention_mask
 
diff --git a/bsmetadata/train.py b/bsmetadata/train.py
index b0b46428..32cd8874 100644
--- a/bsmetadata/train.py
+++ b/bsmetadata/train.py
@@ -136,7 +136,7 @@ def loss_fn(batch, outputs, metadata_mask=None):
     shift_logits = lm_logits[..., :-1, :].contiguous()
     shift_labels = labels[..., 1:].contiguous()
     if metadata_mask is not None:
-        loss_mask = torch.logical_and(attention_mask, ~metadata_mask)
+        loss_mask = torch.logical_and(attention_mask, ~(metadata_mask.bool()))
     else:
         loss_mask = attention_mask
     shift_mask = loss_mask[..., 1:].contiguous()

From 564569149ac45c1525514d5095fa5d1425039493 Mon Sep 17 00:00:00 2001
From: Jonathan Chang
Date: Sun, 8 Jan 2023 21:17:46 +0800
Subject: [PATCH 02/13] Fix file list

---
 bsmetadata/experiments/datasetv2.py | 9 ---------
 1 file changed, 9 deletions(-)

diff --git a/bsmetadata/experiments/datasetv2.py b/bsmetadata/experiments/datasetv2.py
index ea4f32ec..3a1969df 100644
--- a/bsmetadata/experiments/datasetv2.py
+++ b/bsmetadata/experiments/datasetv2.py
@@ -265,15 +265,6 @@
     "c4-en-html_cc-main-2019-18_pq01-009.jsonl.gz",
     "c4-en-html_cc-main-2019-18_pq01-010.jsonl.gz",
     "c4-en-html_cc-main-2019-18_pq01-011.jsonl.gz",
-    "c4-en-html_cc-main-2019-18_pq01-012.jsonl.gz",
-    "c4-en-html_cc-main-2019-18_pq01-013.jsonl.gz",
-    "c4-en-html_cc-main-2019-18_pq01-014.jsonl.gz",
-    "c4-en-html_cc-main-2019-18_pq01-016.jsonl.gz",
-    "c4-en-html_cc-main-2019-18_pq01-017.jsonl.gz",
-    "c4-en-html_cc-main-2019-18_pq01-018.jsonl.gz",
-    "c4-en-html_cc-main-2019-18_pq01-019.jsonl.gz",
-    "c4-en-html_cc-main-2019-18_pq01-020.jsonl.gz",
-    "c4-en-html_cc-main-2019-18_pq01-021.jsonl.gz",
 ]
 
 features = {

From a6eaeafb7feb6331afecbb964d1b37b1ad8861a7 Mon Sep 17 00:00:00 2001
From: Jonathan Chang
Date: Mon, 12 Dec 2022 23:04:43 +0800
Subject: [PATCH 03/13] Don't raise error on extra import

---
 .../preprocessing_tools/wikipedia_desc_utils.py | 12 +++++++++++-
 bsmetadata/preprocessing_utils.py               | 15 +++++++++++----
 2 files changed, 22 insertions(+), 5 deletions(-)

diff --git a/bsmetadata/preprocessing_tools/wikipedia_desc_utils.py b/bsmetadata/preprocessing_tools/wikipedia_desc_utils.py
index c2a0fe14..957f157f 100644
--- a/bsmetadata/preprocessing_tools/wikipedia_desc_utils.py
+++ b/bsmetadata/preprocessing_tools/wikipedia_desc_utils.py
@@ -4,11 +4,21 @@
 from typing import Optional
 
 import nltk
-from wikipedia2vec.dump_db import DumpDB
+
+wikipedia2vec_available = True
+try:
+    from wikipedia2vec.dump_db import DumpDB
+except ImportError:
+    wikipedia2vec_available = False
 
 
 class WikipediaDescUtils:
     def __init__(self, path_wiki_db) -> None:
+        if not wikipedia2vec_available:
+            raise ImportError(
+                "Please install wikipedia2vec to use this feature. "
+                "You can do so by running `pip install -e '.[website_description]'`."
+            )
         self.cache = defaultdict(str)
         self.wiki_dump_db = DumpDB(path_wiki_db)
         self.redirects_map = {
diff --git a/bsmetadata/preprocessing_utils.py b/bsmetadata/preprocessing_utils.py
index 2f635a18..cfa93db6 100644
--- a/bsmetadata/preprocessing_utils.py
+++ b/bsmetadata/preprocessing_utils.py
@@ -24,10 +24,15 @@
 from bs_dateutil.parser import ParserError, parse
 from datasets import Value
-from REL.entity_disambiguation import EntityDisambiguation
-from REL.mention_detection import MentionDetection
-from REL.ner import load_flair_ner
-from REL.utils import process_results
+
+REL_available = True
+try:
+    from REL.entity_disambiguation import EntityDisambiguation
+    from REL.mention_detection import MentionDetection
+    from REL.ner import load_flair_ner
+    from REL.utils import process_results
+except ImportError:
+    REL_available = False
 
 from bsmetadata.paragraph_by_metadata_html import get_paragraphs
 from bsmetadata.preprocessing_tools import html_parser
@@ -373,6 +378,8 @@ def __init__(
         col_to_store_metadata="metadata",
         col_text="text",
     ):
+        if not REL_available:
+            raise ImportError("REL is not available. Please install the extra with `pip install -e '.[entity]'`")
         self.base_url = base_url
         self.wiki_version = "wiki_2019"
         self.mention_detection = MentionDetection(self.base_url, self.wiki_version)

From 1b8c6879d7e1d0e439ff8434276db5a96800f15a Mon Sep 17 00:00:00 2001
From: Jonathan Chang
Date: Mon, 16 Jan 2023 23:24:04 +0800
Subject: [PATCH 04/13] WIP

---
 bsmetadata/deepspeed_configs/v2.json |  6 +++---
 bsmetadata/hydra_configs/v2.yaml     | 12 ++++++------
 bsmetadata/metadata_utils.py         |  2 +-
 experiments/hpsearch/test.sh         | 10 +++++++---
 requirements.txt                     |  1 +
 5 files changed, 18 insertions(+), 13 deletions(-)

diff --git a/bsmetadata/deepspeed_configs/v2.json b/bsmetadata/deepspeed_configs/v2.json
index 015c906b..1d5c0311 100644
--- a/bsmetadata/deepspeed_configs/v2.json
+++ b/bsmetadata/deepspeed_configs/v2.json
@@ -37,12 +37,12 @@
         "reduce_scatter": true,
         "reduce_bucket_size": 500000000,
         "contiguous_gradients": true,
-        "cpu_offload": false
+        "cpu_offload": true
     },
-    "gradient_accumulation_steps": 1,
+    "gradient_accumulation_steps": 16,
     "gradient_clipping": "auto",
     "steps_per_print": 100,
     "train_batch_size": 256,
     "train_micro_batch_size_per_gpu": "auto",
     "wall_clock_breakdown": false
-}
+}
\ No newline at end of file
diff --git a/bsmetadata/hydra_configs/v2.yaml b/bsmetadata/hydra_configs/v2.yaml
index ee6a6a5d..742b16db 100644
--- a/bsmetadata/hydra_configs/v2.yaml
+++ b/bsmetadata/hydra_configs/v2.yaml
@@ -75,9 +75,9 @@ data_config:
     local_metadata_special_token_end:
       entity_paragraph: " "
     local_metadata_special_token_state: true
-  experiment: with_metadata_datasetv2
-  per_device_eval_batch_size: 32
-  per_device_train_batch_size: 32
+  experiment: with_metadata_datasetv2_tf
+  per_device_eval_batch_size: 8
+  per_device_train_batch_size: 8
   dataset_name: bs-modeling-metadata/c4-en-html-with-metadata
   dataset_config_name: null
   train_file: c4-en-html_cc-main-2019-18_pq00-001.jsonl.gz
@@ -108,9 +108,9 @@
 eval_num_per_epoch: 3
 eval_steps: 2000
 save_strategy: STEPS
 save_num_per_epoch: 3
-save_steps: 2000
+save_steps: 150
 do_train: true
 do_eval: true
 gradient_checkpointing: true
-resume_from_checkpoint_dir: null
-gradient_accumulation_steps: 1
+resume_from_checkpoint_dir: null
+gradient_accumulation_steps: 16
diff --git a/bsmetadata/metadata_utils.py b/bsmetadata/metadata_utils.py
index baa632c2..f3f942e7 100644
--- a/bsmetadata/metadata_utils.py
+++ b/bsmetadata/metadata_utils.py
@@ -208,7 +208,7 @@ def random_sample_metadata_v2(
Returns: A new collection of examples, with some metadata dropped. """ - only_metadata_types = list(metadata_type_sample_weights.keys()) + only_metadata_types = [key for key in metadata_type_sample_weights.keys() if f"metadata_{key}" in examples] for i in range(len(examples["text"])): example = {k: v[i] for k, v in examples.items()} metadata_types = [key for key in only_metadata_types if example[f"metadata_{key}"]] diff --git a/experiments/hpsearch/test.sh b/experiments/hpsearch/test.sh index 8aa6ba3b..01c0195c 100644 --- a/experiments/hpsearch/test.sh +++ b/experiments/hpsearch/test.sh @@ -1,7 +1,8 @@ export MODEL=gpt2-xl -export NUM_GPU=8 +export NUM_GPU=2 export DEEPSPEED_CONFIG=$(realpath bsmetadata/deepspeed_configs/v2.json) +export DATA_DIR=$(realpath local-data) echo "deepspeed_config_file: $DEEPSPEED_CONFIG" echo "compute_environment: LOCAL_MACHINE deepspeed_config: @@ -21,7 +22,10 @@ accelerate launch --config_file accelerate_config.yaml bsmetadata/train.py --con model_name=$MODEL \ data_config.train_file='*.jsonl.gz' \ data_config.validation_file='c4-en-html_cc-main-2019-18_pq00-001.jsonl.gz' \ - data_config.preprocessing_num_workers=48 extra_steps_to_eval_save_at='[2]' \ - data_config.streaming=True out_dir=/mnt/ssd-1/bigscience-metadata/lower-lr + data_config.dataset_name=$DATA_DIR \ + data_config.preprocessing_num_workers=6 extra_steps_to_eval_save_at='[2,100,200,400,800]' \ + data_config.metadata_config.metadata_list='[html]' \ + data_config.metadata_config.metadata_column_list='[html]' \ + out_dir=$HOME/tmp/metadata-run-html #out_dir=/mnt/ssd-1/bigscience-metadata/run1 #data_config.train_file='c4-en-html_cc*.jsonl.gz' data_config.streaming=True out_dir=/mnt/ssd-1/bigscience-metadata/run1 diff --git a/requirements.txt b/requirements.txt index 21dad12b..7214f13e 100644 --- a/requirements.txt +++ b/requirements.txt @@ -15,3 +15,4 @@ lxml==4.6.5 htmlmin==0.1.12 loguru>=0.6.0 deepspeed>=0.6.1 +tensorflow # for data processing From 151d3a3d19481651d023b32ce663138c8ab0253a Mon Sep 17 00:00:00 2001 From: Jonathan Chang Date: Fri, 20 Jan 2023 20:43:29 +0800 Subject: [PATCH 05/13] stash --- bsmetadata/evaluation.py | 171 +++++++++++++++++- bsmetadata/hydra_configs/v2.yaml | 2 +- bsmetadata/metadata_utils.py | 4 +- requirements_resolved_with_extras_and_dev.txt | 2 +- 4 files changed, 166 insertions(+), 13 deletions(-) diff --git a/bsmetadata/evaluation.py b/bsmetadata/evaluation.py index 7285dd94..097b4f51 100644 --- a/bsmetadata/evaluation.py +++ b/bsmetadata/evaluation.py @@ -1,17 +1,32 @@ import argparse import functools +import itertools +import json from typing import Dict +import rich import torch import torch.nn.functional as F from datasets import load_dataset from huggingface_hub import hf_hub_download from omegaconf import OmegaConf +from rich.text import Text from tqdm.auto import tqdm + +from bsmetadata.metadata_utils import add_metadata_and_chunk_examples from transformers import AutoModelForCausalLM, AutoTokenizer, default_data_collator from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions -from bsmetadata.metadata_utils import add_metadata_and_chunk_examples + +def format_by_one_mask(input_ids, mask, tokenizer): + i = 0 + data = [] + for key, igroup in itertools.groupby(mask): + size = len(list(igroup)) + text = tokenizer.decode(input_ids[i : i + size]) + i += size + data.append((text, "green" if key else None)) + return Text.assemble(*data) @torch.no_grad() @@ -35,20 +50,69 @@ def ppl_fn( shift_logits = lm_logits[..., :-1, :].contiguous() 
shift_labels = labels[..., 1:].contiguous() - if metadata_mask is not None: - loss_mask = torch.logical_and(attention_mask, ~(metadata_mask.bool())) + rich.print(f"{~(metadata_mask.bool())=}") + rich.print("(attention_mask.bool())") + rich.print(attention_mask.bool()) + loss_mask = torch.logical_and(attention_mask.bool(), ~(metadata_mask.bool())) + rich.print(f"{loss_mask=}") else: - loss_mask = attention_mask - + loss_mask = attention_mask.bool() shift_mask = loss_mask[..., 1:].contiguous() + """ + + max len: 10 + (label, by convention, is unshifted) + label: a b c d e f g x x x + input: a b c d e f g x x x + mask : 1 1 1 1 1 1 1 0 0 0 + + shift label : b c d e f g x x x + shift logit : a b c d e f g x x + shift a mask: 1 1 1 1 1 1 0 0 0 + + + calculated part + input: a b c d e f + label: b c d e f g + + metdata example: + label : M M a b c d e f g x + input : M M a b c d e f g x + a mask: 1 1 1 1 1 1 1 1 1 0 + m mask: 1 1 0 0 0 0 0 0 0 0 + a & !m: 0 0 1 1 1 1 1 1 1 0 + + shift label : M a b c d e f g x + shift logit : M M a b c d e f g + shift a mask: 1 1 1 1 1 1 1 1 0 + shift (a&!m): 0 1 1 1 1 1 1 1 0 + diff (bug) : x + + # fix: mask out the loss if ((the source token is metadata) or (the target token is padding)) + # + + shift m mask: + ideal mask : + """ + + # if metadata_mask is not None: + # shift_metadata_mask = metadata_mask[..., 1:].contiguous().bool() + # shift_mask = torch.logical_and(shift_mask, ~shift_metadata_mask) + rich.print(f"shift_mask{shift_mask}") + rich.print(f"{shift_mask.sum()=}") + # Flatten the tokens loss = F.cross_entropy( shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1), reduction="none", ).view(b, -1) + loss = loss.cpu().squeeze().numpy().tolist() + shift_mask = shift_mask.cpu().squeeze().numpy().tolist() + return loss, shift_mask, shift_labels.cpu().squeeze().numpy().tolist() + return loss, shift_mask # Normalize to avoid an overflow when there are many tokens normed_loss_weights = shift_mask / shift_mask.sum() @@ -104,18 +168,45 @@ def get_ppl(batch: Dict[str, torch.Tensor]) -> torch.Tensor: action="store_true", help="If set to true, the script runs in test mode and only takes 10 examples per dataset", ) + parser.add_argument( + "--local", + action="store_true", + help="If set to true, the script runs in test mode and only takes 10 examples per dataset", + ) + parser.add_argument( + "--metadata_to_test", + type=str, + default="html,entity,entity_paragraph,website_desc,generation_datasource,timestamp,title,generation_length_sentence,generation_length_text,url,paragraph", + help="metadata types to test", + ) + parser.add_argument( + "--untrained", + action="store_true", + help="If set to true, will load gpt2-xl", + ) args = parser.parse_args() print(f"Parameters: {args}") # Load config - config_file_path = hf_hub_download(repo_id=args.repo_id, filename="actual_config.yaml", use_auth_token=True) + if args.local: + import os + + config_file_path = os.path.join(args.repo_id, "actual_config.yaml") + else: + config_file_path = hf_hub_download(repo_id=args.repo_id, filename="actual_config.yaml", use_auth_token=True) repo_args = OmegaConf.load(config_file_path) data_config = repo_args.data_config + # make sure loss (ppl) masking is on for local metadata + data_config.metadata_config.treat_local_metadata_as_regular_text = False + # Load model print("Loading model...") - model = AutoModelForCausalLM.from_pretrained(args.repo_id, subfolder=args.subfolder, use_auth_token=True) + if args.untrained: + model = 
AutoModelForCausalLM.from_pretrained("gpt2-xl") + else: + model = AutoModelForCausalLM.from_pretrained(args.repo_id, subfolder=args.subfolder, use_auth_token=True) model.eval().cuda() if not args.no_cuda else model.eval() # Load tokenizer @@ -130,6 +221,7 @@ def get_ppl(batch: Dict[str, torch.Tensor]) -> torch.Tensor: preprocess_fn = functools.partial(add_metadata_and_chunk_examples, tokenizer=tokenizer, cfg=cfg) # Validation datasets + dataset_paths = [ "bs-modeling-metadata/c4-en-html-with-validation_metadata_html", "bs-modeling-metadata/c4-en-html-with-validation_metadata_entity", @@ -143,6 +235,7 @@ def get_ppl(batch: Dict[str, torch.Tensor]) -> torch.Tensor: "bs-modeling-metadata/c4-en-html-with-validation_metadata_url", "bs-modeling-metadata/c4-en-html-with-validation_metadata_paragraph", ] + dataset_paths = [path for path in dataset_paths if path.split("_metadata_")[1] in args.metadata_to_test.split(",")] for path in dataset_paths: n_examples = 0 @@ -176,6 +269,10 @@ def get_ppl(batch: Dict[str, torch.Tensor]) -> torch.Tensor: normal_example = tokenizer(examples["text"][0]) normal_example_len = len(normal_example["input_ids"]) metadata_example = {k: v[0] for k, v in processed_examples.items()} + # rich.print(f"{metadata_example['attention_mask']=}") + # rich.print(f"{normal_example['attention_mask']=}") + # import sys + # sys.exit() metadata_example_len = len(metadata_example["input_ids"]) min_seq_len = min(normal_example_len, metadata_example_len) max_seq_len = max(normal_example_len, metadata_example_len) @@ -197,12 +294,68 @@ def get_ppl(batch: Dict[str, torch.Tensor]) -> torch.Tensor: if not args.no_cuda: normal_batch = {k: v.cuda() for k, v in normal_batch.items()} metadata_batch = {k: v.cuda() for k, v in metadata_batch.items()} + if n_examples == 1: + ex = format_by_one_mask(normal_batch["input_ids"][0], normal_batch["attention_mask"][0], tokenizer) + rich.print(f"Normal example:") + rich.print(ex) + + ex = format_by_one_mask( + metadata_batch["input_ids"][0], metadata_batch["metadata_mask"][0], tokenizer + ) + rich.print(f"Metadata example:") + rich.print(ex) + rich.print(tokenizer.decode(metadata_batch["input_ids"][0])) # Calculate ppl normal_ppl = get_ppl(normal_batch) - total_normal_ppl += float(normal_ppl) * normal_example_len + # total_normal_ppl += float(normal_ppl) * normal_example_len metadata_ppl = get_ppl(metadata_batch) - total_metadata_ppl += float(metadata_ppl) * metadata_example_len + # total_metadata_ppl += float(metadata_ppl) * metadata_example_len + if n_examples == 1: + loss, mask, shift_labels = normal_ppl + printed = 0 + for i, (l, m, sl) in enumerate(zip(loss, mask, shift_labels)): + if m: + if printed < 10: + rich.print(f"Loss {json.dumps(tokenizer.decode(sl))}: {l}") + printed += 1 + + unmasked_labels = [label for label, m in zip(shift_labels, mask) if m] + # print(f"first 10 unmasked labels: {[tokenizer.decode(x) for x in unmasked_labels[:10]]}") + print(f"first 10 unmasked labels: {tokenizer.decode(unmasked_labels[:10])}") + # ex = format_by_one_mask(normal_batch["input_ids"][0], mask, tokenizer) + # rich.print(ex) + + loss, mask, shift_labels = metadata_ppl + printed = 0 + for i, (l, m, sl) in enumerate(zip(loss, mask, shift_labels)): + if m: + if printed < 10: + rich.print(f"Loss {json.dumps(tokenizer.decode(sl))}: {l}") + printed += 1 + + unmasked_labels = [label for label, m in zip(shift_labels, mask) if m] + print(f"first 10 unmasked labels: {tokenizer.decode(unmasked_labels[:10])}") + # ex = format_by_one_mask(metadata_batch["input_ids"][0], 
mask, tokenizer) + # rich.print(ex) + + # ex = format_by_one_mask(normal_batch["input_ids"][0], normal_batch["attention_mask"][0], tokenizer) + # rich.print(ex) + # rich.print(f"Normal example: (ppl={normal_ppl[0]})") + + # ex = format_by_one_mask( + # metadata_batch["input_ids"][0], metadata_batch["metadata_mask"][0], tokenizer + # ) + # rich.print(ex) + # rich.print(f"Metadata example: (ppl={metadata_ppl[0]})") + # rich.print(f"Normal example: (mask={normal_ppl[1]})") + # rich.print(f"Metadata example: (mask={metadata_ppl[1]})") + import sys + + sys.exit() + + if n_examples > 1000: + break if exit_flag: continue diff --git a/bsmetadata/hydra_configs/v2.yaml b/bsmetadata/hydra_configs/v2.yaml index 742b16db..5a48f12d 100644 --- a/bsmetadata/hydra_configs/v2.yaml +++ b/bsmetadata/hydra_configs/v2.yaml @@ -38,7 +38,7 @@ data_config: #- generation_length_sentence #- generation_length_text - entity_paragraph - local_metadata_special_tokens: + local_metadata_special_tokens: entity_paragraph: "entity" metadata_sep: ' | ' metadata_key_value_sep: ': ' diff --git a/bsmetadata/metadata_utils.py b/bsmetadata/metadata_utils.py index f3f942e7..5f433e65 100644 --- a/bsmetadata/metadata_utils.py +++ b/bsmetadata/metadata_utils.py @@ -20,9 +20,9 @@ from typing import Any, DefaultDict, Dict, List, Optional, Tuple import numpy as np -from transformers import PreTrainedTokenizerFast from bsmetadata.metadata_processors import PROCESSORS, MetadataConfig, MetadataProcessor +from transformers import PreTrainedTokenizerFast logger = logging.getLogger(__name__) @@ -100,7 +100,7 @@ def add_metadata_and_chunk_examples( char_level_metadata_mask = [False] * len(text_with_local_metadata) if metadata_prefix_encoded: - text_with_local_metadata = " " + text_with_local_metadata + text_with_local_metadata = "" + text_with_local_metadata char_level_metadata_mask = [False] + char_level_metadata_mask text_with_local_metadata_encoded = tokenizer.encode_plus(text_with_local_metadata) diff --git a/requirements_resolved_with_extras_and_dev.txt b/requirements_resolved_with_extras_and_dev.txt index 99287699..bdada4a1 100644 --- a/requirements_resolved_with_extras_and_dev.txt +++ b/requirements_resolved_with_extras_and_dev.txt @@ -110,7 +110,7 @@ threadpoolctl==3.1.0 ; python_full_version >= "3.7.11" and python_version < "3.9 tokenizers==0.13.1 ; python_full_version >= "3.7.11" and python_version < "3.9" toml==0.10.2 ; python_full_version >= "3.7.11" and python_version < "3.9" tomli==2.0.1 ; python_full_version >= "3.7.11" and python_version < "3.9" -torch==1.9.0 ; python_full_version >= "3.7.11" and python_version < "3.9" +#torch==1.9.0 ; python_full_version >= "3.7.11" and python_version < "3.9" tqdm==4.64.1 ; python_full_version >= "3.7.11" and python_version < "3.9" transformers==4.23.1 ; python_full_version >= "3.7.11" and python_version < "3.9" typed-ast==1.5.4 ; python_version < "3.8" and implementation_name == "cpython" and python_full_version >= "3.7.11" From d100c26eff7df2446aa081972c37bcc35d0da974 Mon Sep 17 00:00:00 2001 From: Jonathan Chang Date: Fri, 20 Jan 2023 22:38:18 +0800 Subject: [PATCH 06/13] Fix --- bsmetadata/evaluation.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/bsmetadata/evaluation.py b/bsmetadata/evaluation.py index 097b4f51..df80857c 100644 --- a/bsmetadata/evaluation.py +++ b/bsmetadata/evaluation.py @@ -51,10 +51,14 @@ def ppl_fn( shift_logits = lm_logits[..., :-1, :].contiguous() shift_labels = labels[..., 1:].contiguous() if metadata_mask is not None: + metadata_mask = 
metadata_mask.bool()
+        nonmetadata_cumsum = torch.cumsum(~metadata_mask, dim=-1)
+        first_nonmetadata = nonmetadata_cumsum == 1
         rich.print(f"{~(metadata_mask.bool())=}")
         rich.print("(attention_mask.bool())")
         rich.print(attention_mask.bool())
         loss_mask = torch.logical_and(attention_mask.bool(), ~(metadata_mask.bool()))
+        loss_mask = torch.logical_and(loss_mask, ~first_nonmetadata)
         rich.print(f"{loss_mask=}")
     else:
         loss_mask = attention_mask.bool()
@@ -313,6 +317,7 @@ def get_ppl(batch: Dict[str, torch.Tensor]) -> torch.Tensor:
                     # total_metadata_ppl += float(metadata_ppl) * metadata_example_len
                     if n_examples == 1:
                         loss, mask, shift_labels = normal_ppl
+                        print("normal ppl")
                         printed = 0
                         for i, (l, m, sl) in enumerate(zip(loss, mask, shift_labels)):
                             if m:
                                 if printed < 10:
@@ -328,6 +333,7 @@ def get_ppl(batch: Dict[str, torch.Tensor]) -> torch.Tensor:
 
                     loss, mask, shift_labels = metadata_ppl
                     printed = 0
+                    print("metadata ppl")
                     for i, (l, m, sl) in enumerate(zip(loss, mask, shift_labels)):
                         if m:
                             if printed < 10:

From 6c3cc9f3f427cecf5209f90ddbc0e51716947750 Mon Sep 17 00:00:00 2001
From: Jonathan Chang
Date: Sat, 21 Jan 2023 11:10:25 +0800
Subject: [PATCH 07/13] Empty commit

From 0a2aab1deeb2c868e5ff3625de168f252de2b409 Mon Sep 17 00:00:00 2001
From: Jonathan Chang
Date: Sat, 21 Jan 2023 11:19:21 +0800
Subject: [PATCH 08/13] empty commit

From cd08a93edb47e3909bbb4963446f1a0b4b51c08d Mon Sep 17 00:00:00 2001
From: Jonathan Chang
Date: Sat, 21 Jan 2023 11:21:59 +0800
Subject: [PATCH 09/13] test 2

From 0b4846daf7314d48e9e2c44e0ee45dd3b1f7f687 Mon Sep 17 00:00:00 2001
From: Niklas Muennighoff
Date: Tue, 24 Jan 2023 18:50:23 +0400
Subject: [PATCH 10/13] Add loss plotting (#178)

---
 bsmetadata/evaluation.py  | 44 ++++++++++++++++++++++++++++++++-------
 bsmetadata/plot_losses.py | 34 ++++++++++++++++++++++++++++++
 2 files changed, 71 insertions(+), 7 deletions(-)
 create mode 100644 bsmetadata/plot_losses.py

diff --git a/bsmetadata/evaluation.py b/bsmetadata/evaluation.py
index df80857c..fcc40e0a 100644
--- a/bsmetadata/evaluation.py
+++ b/bsmetadata/evaluation.py
@@ -31,7 +31,11 @@ def format_by_one_mask(input_ids, mask, tokenizer):
 
 @torch.no_grad()
 def ppl_fn(
-    batch: Dict[str, torch.Tensor], outputs: CausalLMOutputWithCrossAttentions, metadata_mask: torch.Tensor = None
+    batch: Dict[str, torch.Tensor],
+    outputs: CausalLMOutputWithCrossAttentions,
+    metadata_mask: torch.Tensor = None,
+    save_data: bool = False,
+    idx: int = None,
 ) -> torch.Tensor:
     """Calculates the perplexity for a given batch.
 
     Args:
         batch: A dict with keys "input_ids" and "attention_mask".
         outputs: The model outputs for the batch.
         metadata_mask: 1 for tokens corresponding to metadata and 0 for all other tokens.
+        save_data: Whether to save tokens & losses.
+        idx: The index of the batch.
 
     Returns:
         The perplexity of the given batch.
@@ -113,8 +119,22 @@ def ppl_fn( shift_labels.view(-1), reduction="none", ).view(b, -1) + + if save_data: + # Save the non-masked tokens & their loss + suffix = "_meta" if metadata_mask is not None else "" + torch.save( + batch["input_ids"], + f"{idx}_input_ids{suffix}.pt", + ) + torch.save( + loss.cpu().squeeze(), + f"{idx}_loss{suffix}.pt", + ) + loss = loss.cpu().squeeze().numpy().tolist() shift_mask = shift_mask.cpu().squeeze().numpy().tolist() + return loss, shift_mask, shift_labels.cpu().squeeze().numpy().tolist() return loss, shift_mask @@ -129,7 +149,11 @@ def ppl_fn( @torch.no_grad() -def get_ppl(batch: Dict[str, torch.Tensor]) -> torch.Tensor: +def get_ppl( + batch: Dict[str, torch.Tensor], + save_data: bool = False, + idx: int = None, +) -> torch.Tensor: """Prepares the arguments for perplexity calculation and passes them to the perplexity function. Args: @@ -137,7 +161,8 @@ def get_ppl(batch: Dict[str, torch.Tensor]) -> torch.Tensor: - the input ids are a list of token ids corresponding to the input text with metadata; - the attention mask is 0 for padding tokens and 1 everywhere else; - the metadata mask is 1 for tokens corresponding to metadata and 0 for all other tokens. - + save_data: Whether to save tokens & losses + idx: The index of the batch for saving Returns: The perplexity of the given batch. """ @@ -145,7 +170,7 @@ def get_ppl(batch: Dict[str, torch.Tensor]) -> torch.Tensor: metadata_mask = batch.pop("metadata_mask", None) outputs = model(**batch) batch["labels"] = labels - ppl = ppl_fn(batch, outputs, metadata_mask) + ppl = ppl_fn(batch, outputs, metadata_mask, save_data=save_data, idx=idx) return ppl @@ -172,6 +197,11 @@ def get_ppl(batch: Dict[str, torch.Tensor]) -> torch.Tensor: action="store_true", help="If set to true, the script runs in test mode and only takes 10 examples per dataset", ) + parser.add_argument( + "--save_data", + action="store_true", + help="If set to true, save tokens & losses", + ) parser.add_argument( "--local", action="store_true", @@ -255,7 +285,7 @@ def get_ppl(batch: Dict[str, torch.Tensor]) -> torch.Tensor: split = "validation" if not args.test else "validation[:10]" validation_dataset = load_dataset(path, use_auth_token=True, split=split) - for example in tqdm(validation_dataset, desc=f"Calculating perplexity for {metadata_type}..."): + for idx, example in tqdm(enumerate(validation_dataset), desc=f"Calculating perplexity for {metadata_type}..."): # Preprocess examples examples = {k: [v] for k, v in example.items()} try: @@ -311,9 +341,9 @@ def get_ppl(batch: Dict[str, torch.Tensor]) -> torch.Tensor: rich.print(tokenizer.decode(metadata_batch["input_ids"][0])) # Calculate ppl - normal_ppl = get_ppl(normal_batch) + normal_ppl = get_ppl(normal_batch, save_data=args.save_data, idx=idx) # total_normal_ppl += float(normal_ppl) * normal_example_len - metadata_ppl = get_ppl(metadata_batch) + metadata_ppl = get_ppl(metadata_batch, save_data=args.save_data, idx=idx) # total_metadata_ppl += float(metadata_ppl) * metadata_example_len if n_examples == 1: loss, mask, shift_labels = normal_ppl diff --git a/bsmetadata/plot_losses.py b/bsmetadata/plot_losses.py new file mode 100644 index 00000000..c9571e83 --- /dev/null +++ b/bsmetadata/plot_losses.py @@ -0,0 +1,34 @@ +import argparse + +import torch +from transformers import AutoTokenizer + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--idx", + type=int, + default=12, + help="Index of the loss to plot", + ) + args = parser.parse_args() + + tokenizer = 
AutoTokenizer.from_pretrained("gpt2") + + # Load the losses + input_ids_meta = torch.load(f"{args.idx}_input_ids_meta.pt") # [batch_size, seq_len] + input_ids = torch.load(f"{args.idx}_input_ids.pt") # [batch_size, seq_len] + + loss_meta = torch.load(f"{args.idx}_loss_meta.pt") # [batch_size, seq_len] + loss = torch.load(f"{args.idx}_loss.pt") # [batch_size, seq_len] + + # Print the losses + print("Meta") + tok_losses = [(tokenizer.decode(input_ids_meta[..., i]), round(loss_meta[..., i].item(), 2)) for i in range(input_ids_meta.shape[-1]-1)] + print(tok_losses) + + print("Normal") + tok_losses = [(tokenizer.decode(input_ids[..., i]), round(loss[..., i].item(), 2)) for i in range(input_ids.shape[-1]-1)] + print(tok_losses) + + From 705ceffca7467ec063f59423fd18a932cb4ff06d Mon Sep 17 00:00:00 2001 From: Jonathan Chang <31893406+cccntu@users.noreply.github.com> Date: Thu, 26 Jan 2023 14:08:39 +0800 Subject: [PATCH 11/13] add notebook --- bsmetadata/nb_plot.py | 35 +++++++++++++++++++++++++++++++++++ 1 file changed, 35 insertions(+) create mode 100644 bsmetadata/nb_plot.py diff --git a/bsmetadata/nb_plot.py b/bsmetadata/nb_plot.py new file mode 100644 index 00000000..7babf987 --- /dev/null +++ b/bsmetadata/nb_plot.py @@ -0,0 +1,35 @@ +# %% +import argparse + +import torch + +from transformers import AutoTokenizer + + +tokenizer = AutoTokenizer.from_pretrained("gpt2") +# %% +idx = 0 +# %% + +# Load the losses +input_ids_meta = torch.load(f"{idx}_input_ids_meta.pt") # [batch_size, seq_len] +input_ids = torch.load(f"{idx}_input_ids.pt") # [batch_size, seq_len] + +loss_meta = torch.load(f"{idx}_loss_meta.pt") # [batch_size, seq_len] +loss = torch.load(f"{idx}_loss.pt") # [batch_size, seq_len] + +# %% +# Print the losses +print("Meta") +tok_losses = [ + (tokenizer.decode(input_ids_meta[..., i]), round(loss_meta[..., i].item(), 2)) + for i in range(input_ids_meta.shape[-1] - 1) +] +print(tok_losses[:32]) + +print("Normal") +tok_losses = [ + (tokenizer.decode(input_ids[..., i]), round(loss[..., i].item(), 2)) for i in range(input_ids.shape[-1] - 1) +] +print(tok_losses[:32]) +# %% From f22d497ea9332af0fca06c4c171e9632ad34f36c Mon Sep 17 00:00:00 2001 From: Niklas Muennighoff Date: Fri, 31 Mar 2023 03:45:52 +0200 Subject: [PATCH 12/13] Changes for eval (#179) * Update evaluation.py * Add prompting baseline * fixed ppl * Update metadata_utils.py * Update evaluation.py --------- Co-authored-by: Paul Pommer Co-authored-by: Jonathan Chang <31893406+cccntu@users.noreply.github.com> --- bsmetadata/evaluation.py | 217 ++++++++++++++++++++++++++++------- bsmetadata/metadata_utils.py | 7 +- 2 files changed, 177 insertions(+), 47 deletions(-) diff --git a/bsmetadata/evaluation.py b/bsmetadata/evaluation.py index fcc40e0a..2da18b1b 100644 --- a/bsmetadata/evaluation.py +++ b/bsmetadata/evaluation.py @@ -1,8 +1,9 @@ +# %%writefile bsmetadata/evaluation.py import argparse import functools import itertools import json -from typing import Dict +from typing import Any, Dict, Optional import rich import torch @@ -12,11 +13,19 @@ from omegaconf import OmegaConf from rich.text import Text from tqdm.auto import tqdm - -from bsmetadata.metadata_utils import add_metadata_and_chunk_examples from transformers import AutoModelForCausalLM, AutoTokenizer, default_data_collator from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions +from bsmetadata import metadata_utils +from bsmetadata.metadata_processors import ( + PROCESSORS, + DatasourceProcessor, + GenerationLengthProcessor, + MetadataConfig, + 
MetadataProcessor, +) +from bsmetadata.metadata_utils import add_metadata_and_chunk_examples, convert_v2_dataset_to_v1_format_v1_compatible + def format_by_one_mask(input_ids, mask, tokenizer): i = 0 @@ -31,8 +40,8 @@ def format_by_one_mask(input_ids, mask, tokenizer): @torch.no_grad() def ppl_fn( - batch: Dict[str, torch.Tensor], - outputs: CausalLMOutputWithCrossAttentions, + batch: Dict[str, torch.Tensor], + outputs: CausalLMOutputWithCrossAttentions, metadata_mask: torch.Tensor = None, save_data: bool = False, idx: int = None, @@ -60,12 +69,12 @@ def ppl_fn( metadata_mask = metadata_mask.bool() nonmetadata_cumsum = torch.cumsum(~metadata_mask, dim=-1) first_nonmetadata = nonmetadata_cumsum == 1 - rich.print(f"{~(metadata_mask.bool())=}") - rich.print("(attention_mask.bool())") - rich.print(attention_mask.bool()) + # rich.print(f"{~(metadata_mask.bool())=}") + # rich.print("(attention_mask.bool())") + # rich.print(attention_mask.bool()) loss_mask = torch.logical_and(attention_mask.bool(), ~(metadata_mask.bool())) loss_mask = torch.logical_and(loss_mask, ~first_nonmetadata) - rich.print(f"{loss_mask=}") + # rich.print(f"{loss_mask=}") else: loss_mask = attention_mask.bool() shift_mask = loss_mask[..., 1:].contiguous() @@ -110,8 +119,8 @@ def ppl_fn( # if metadata_mask is not None: # shift_metadata_mask = metadata_mask[..., 1:].contiguous().bool() # shift_mask = torch.logical_and(shift_mask, ~shift_metadata_mask) - rich.print(f"shift_mask{shift_mask}") - rich.print(f"{shift_mask.sum()=}") + # rich.print(f"shift_mask{shift_mask}") + # rich.print(f"{shift_mask.sum()=}") # Flatten the tokens loss = F.cross_entropy( @@ -132,20 +141,36 @@ def ppl_fn( f"{idx}_loss{suffix}.pt", ) - loss = loss.cpu().squeeze().numpy().tolist() - shift_mask = shift_mask.cpu().squeeze().numpy().tolist() + # loss = loss.cpu().squeeze().numpy().tolist() + # shift_mask = shift_mask.cpu().squeeze().numpy().tolist() + # return loss, shift_mask, shift_labels.cpu().squeeze().numpy().tolist() + # return loss, shift_mask + + if save_data: + # Save the non-masked tokens & their loss + suffix = "_meta" if metadata_mask is not None else "" + torch.save( + { + 'loss': loss, + 'shift_mask':shift_mask, + 'input_ids': batch['input_ids'], + 'attention_mask': attention_mask, + 'metadata_mask': metadata_mask, + } + , + + f"{idx}_data{suffix}.pt", + ) - return loss, shift_mask, shift_labels.cpu().squeeze().numpy().tolist() - return loss, shift_mask # Normalize to avoid an overflow when there are many tokens normed_loss_weights = shift_mask / shift_mask.sum() loss = (loss * normed_loss_weights).sum() # Per-example ppl - ppl = torch.exp((loss * shift_mask).sum(-1) / shift_mask.sum(-1)) + #ppl = torch.exp((loss * shift_mask).sum(-1) / shift_mask.sum(-1)) - return ppl + return loss, shift_mask.sum() @torch.no_grad() @@ -174,6 +199,69 @@ def get_ppl( return ppl +def datasource_process_global_for_prompt(self, metadata_attrs: Dict[str, Any]) -> Optional[str]: + # We represent the DATASOURCE by using meaningful information of the URL. + # URL: http://www.example.de/2015/forum/article/21-new-project + # Example: example.de > forum > article > new project + return "".join(["Data source", self.cfg.metadata_key_value_sep, metadata_attrs["value"]]) + + +def generation_length_process_global_for_prompt(self, metadata_attrs: Dict[str, Any]) -> Optional[str]: + # We represent the length of a text by the number of characters. 
+ # Example: Length: 123 + return "".join(["Number of characters", self.cfg.metadata_key_value_sep, metadata_attrs["value"]]) + + +def create_metadata_prompt(example: Dict[str, Any], cfg: MetadataConfig) -> str: + """Creates a prompt containing all global metadata information (including URLs, timestamps, etc) + and/or local metadata special tokens + Args: + example: The example to create a metadata prefix for. + cfg: The data config to use. + Returns: + A string containing the metadata prefix. + """ + LIST_LIKE_METADATA_PROMPT_FIELDS = { + "entity": "Entities", + "entity_paragraph": "Entity Paragraphs", + } + example = convert_v2_dataset_to_v1_format_v1_compatible(example=example) + processed_metadata = {} + for metadata in example["metadata"]: + key, type_ = metadata["key"], metadata["type"] + if key not in cfg.metadata_list: + # rich.print(f"metadata key not in metadata_list, skipping. {key}, {cfg.metadata_list}") + continue + + if type_ == "global": + processor = PROCESSORS.get(key, MetadataProcessor)(cfg) + processed_metadata[key] = processor.process_global(metadata) + elif key in LIST_LIKE_METADATA_PROMPT_FIELDS: + if key not in processed_metadata: + processed_metadata[key] = set() # Same entities may occurr at different positions + processed_metadata[key].add(metadata["value"]) + elif ( + cfg.add_local_metadata_special_tokens_in_prefix + and cfg.local_metadata_special_tokens + and key in cfg.local_metadata_special_tokens + ): + processed_metadata[key] = cfg.local_metadata_special_tokens[key] + else: + processed_metadata[key] = key.title() + + for list_like_metadata in LIST_LIKE_METADATA_PROMPT_FIELDS: + if list_like_metadata in processed_metadata: + processed_metadata[list_like_metadata] = ( + LIST_LIKE_METADATA_PROMPT_FIELDS[list_like_metadata] + + cfg.metadata_key_value_sep + + ", ".join(v.replace("_", " ") for v in processed_metadata[list_like_metadata]) + ) + + sorted_metadata = [processed_metadata.get(md, None) for md in cfg.metadata_list] + sorted_metadata = [md for md in sorted_metadata if md is not None] + return cfg.metadata_sep.join(sorted_metadata) + cfg.metadata_prefix_sep if sorted_metadata else "" + + if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument( @@ -218,6 +306,11 @@ def get_ppl( action="store_true", help="If set to true, will load gpt2-xl", ) + parser.add_argument( + "--prompt", + action="store_true", + help="If set to true, the script evaluates metadata in prompt style", + ) args = parser.parse_args() print(f"Parameters: {args}") @@ -252,10 +345,18 @@ def get_ppl( cfg.metadata_probability = 1.0 cfg.entity_setting = "beg" cfg.metadata_list.append("entity") + cfg.metadata_list.append("paragraph") + + if args.prompt: + cfg.metadata_sep = "; " # Instead of " | " + cfg.metadata_prefix_sep = "" # Instead of " |||"; there's already an implicit " " + DatasourceProcessor.process_global = datasource_process_global_for_prompt + GenerationLengthProcessor.process_global = generation_length_process_global_for_prompt + metadata_utils.create_metadata_prefix = create_metadata_prompt + preprocess_fn = functools.partial(add_metadata_and_chunk_examples, tokenizer=tokenizer, cfg=cfg) # Validation datasets - dataset_paths = [ "bs-modeling-metadata/c4-en-html-with-validation_metadata_html", "bs-modeling-metadata/c4-en-html-with-validation_metadata_entity", @@ -273,10 +374,10 @@ def get_ppl( for path in dataset_paths: n_examples = 0 - total_normal_len = 0 - total_normal_ppl = 0 - total_metadata_len = 0 - total_metadata_ppl = 0 + total_normal_len = [] + 
total_normal_ppl = [] + total_metadata_len = [] + total_metadata_ppl = [] exit_flag = False # Load validation dataset from hugging face @@ -285,7 +386,10 @@ def get_ppl( split = "validation" if not args.test else "validation[:10]" validation_dataset = load_dataset(path, use_auth_token=True, split=split) + data = [] for idx, example in tqdm(enumerate(validation_dataset), desc=f"Calculating perplexity for {metadata_type}..."): + #for idx in [136,]: + example = validation_dataset[idx] # Preprocess examples examples = {k: [v] for k, v in example.items()} try: @@ -307,6 +411,10 @@ def get_ppl( # rich.print(f"{normal_example['attention_mask']=}") # import sys # sys.exit() + # print(metadata_example) + if "input_ids" not in metadata_example: + print("Skipping") + continue metadata_example_len = len(metadata_example["input_ids"]) min_seq_len = min(normal_example_len, metadata_example_len) max_seq_len = max(normal_example_len, metadata_example_len) @@ -316,9 +424,9 @@ def get_ppl( # 2) examples fitting the model sequence length if len(processed_examples["input_ids"]) == 1 and min_seq_len > 0 and max_seq_len <= 1024: # Keep track of considered examples and total length + if n_examples % 10 == 0: + print("n_examples completed.") n_examples += 1 - total_normal_len += normal_example_len - total_metadata_len += metadata_example_len # Prepare batches normal_example["labels"] = normal_example["input_ids"] @@ -330,48 +438,56 @@ def get_ppl( metadata_batch = {k: v.cuda() for k, v in metadata_batch.items()} if n_examples == 1: ex = format_by_one_mask(normal_batch["input_ids"][0], normal_batch["attention_mask"][0], tokenizer) - rich.print(f"Normal example:") - rich.print(ex) + # rich.print(f"Normal example:") + # rich.print(ex) ex = format_by_one_mask( metadata_batch["input_ids"][0], metadata_batch["metadata_mask"][0], tokenizer ) - rich.print(f"Metadata example:") - rich.print(ex) - rich.print(tokenizer.decode(metadata_batch["input_ids"][0])) + # rich.print(f"Metadata example:") + # rich.print(ex) + # rich.print(tokenizer.decode(metadata_batch["input_ids"][0])) # Calculate ppl - normal_ppl = get_ppl(normal_batch, save_data=args.save_data, idx=idx) - # total_normal_ppl += float(normal_ppl) * normal_example_len - metadata_ppl = get_ppl(metadata_batch, save_data=args.save_data, idx=idx) - # total_metadata_ppl += float(metadata_ppl) * metadata_example_len - if n_examples == 1: + normal_ppl, normal_example_len = get_ppl(normal_batch, save_data=args.save_data, idx=idx) # [0] + # print("PPL") + # print(normal_ppl) + total_normal_ppl.append(normal_ppl)# * normal_example_len + metadata_ppl, metadata_example_len = get_ppl(metadata_batch, save_data=args.save_data, idx=idx) # [0] + # print(metadata_ppl) + total_metadata_ppl.append(metadata_ppl)# * metadata_example_len + + total_normal_len.append(normal_example_len) + total_metadata_len.append(metadata_example_len) + + data.append({'idx':idx,'normal_ppl':normal_ppl, 'metadata_ppl':metadata_ppl}) + if False: # n_examples == 1: loss, mask, shift_labels = normal_ppl - print("normal ppl") + # print("normal ppl") printed = 0 for i, (l, m, sl) in enumerate(zip(loss, mask, shift_labels)): if m: if printed < 10: - rich.print(f"Loss {json.dumps(tokenizer.decode(sl))}: {l}") + # rich.print(f"Loss {json.dumps(tokenizer.decode(sl))}: {l}") printed += 1 unmasked_labels = [label for label, m in zip(shift_labels, mask) if m] # print(f"first 10 unmasked labels: {[tokenizer.decode(x) for x in unmasked_labels[:10]]}") - print(f"first 10 unmasked labels: 
{tokenizer.decode(unmasked_labels[:10])}") + # print(f"first 10 unmasked labels: {tokenizer.decode(unmasked_labels[:10])}") # ex = format_by_one_mask(normal_batch["input_ids"][0], mask, tokenizer) # rich.print(ex) loss, mask, shift_labels = metadata_ppl printed = 0 - print("metadata ppl") + # print("metadata ppl") for i, (l, m, sl) in enumerate(zip(loss, mask, shift_labels)): if m: if printed < 10: - rich.print(f"Loss {json.dumps(tokenizer.decode(sl))}: {l}") + # rich.print(f"Loss {json.dumps(tokenizer.decode(sl))}: {l}") printed += 1 unmasked_labels = [label for label, m in zip(shift_labels, mask) if m] - print(f"first 10 unmasked labels: {tokenizer.decode(unmasked_labels[:10])}") + # print(f"first 10 unmasked labels: {tokenizer.decode(unmasked_labels[:10])}") # ex = format_by_one_mask(metadata_batch["input_ids"][0], mask, tokenizer) # rich.print(ex) @@ -386,9 +502,9 @@ def get_ppl( # rich.print(f"Metadata example: (ppl={metadata_ppl[0]})") # rich.print(f"Normal example: (mask={normal_ppl[1]})") # rich.print(f"Metadata example: (mask={metadata_ppl[1]})") - import sys + # import sys - sys.exit() + # sys.exit() if n_examples > 1000: break @@ -402,8 +518,20 @@ def get_ppl( # Get average ppl weighted by token sequence length if n_examples > 0: - final_normal_ppl = total_normal_ppl / total_normal_len - final_metadata_ppl = total_metadata_ppl / total_metadata_len + def ppl(examples_mean_loss, examples_len): + examples_mean_loss = torch.tensor(examples_mean_loss) + examples_len = torch.tensor(examples_len) + weight = examples_len / examples_len.sum() + return torch.exp2((examples_mean_loss * weight).sum()).item() + + torch.save({ + 'total_normal_ppl': total_normal_ppl, + 'total_metadata_ppl': total_metadata_ppl, + 'total_normal_len': total_normal_len, + 'total_metadata_len': total_metadata_len, + }, 'eva.data2') + final_normal_ppl = ppl(total_normal_ppl, total_normal_len) + final_metadata_ppl = ppl(total_metadata_ppl, total_metadata_len) else: final_metadata_ppl = final_normal_ppl = 0 @@ -412,3 +540,4 @@ def get_ppl( f.write(f"=== RESULT [{metadata_type}] ===\n") f.write("Perplexity (metadata): {:>6,.3f}\n".format(final_metadata_ppl)) f.write("Perplexity (normal): {:>6,.3f}\n\n".format(final_normal_ppl)) + torch.save(data, 'eva.data') diff --git a/bsmetadata/metadata_utils.py b/bsmetadata/metadata_utils.py index 5f433e65..63b2b18e 100644 --- a/bsmetadata/metadata_utils.py +++ b/bsmetadata/metadata_utils.py @@ -1,3 +1,4 @@ +# %%writefile bsmetadata/metadata_utils.py # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at @@ -99,9 +100,9 @@ def add_metadata_and_chunk_examples( text_with_local_metadata = example["text"] char_level_metadata_mask = [False] * len(text_with_local_metadata) - if metadata_prefix_encoded: - text_with_local_metadata = "" + text_with_local_metadata - char_level_metadata_mask = [False] + char_level_metadata_mask + # if metadata_prefix_encoded: + # text_with_local_metadata = "" + text_with_local_metadata + # char_level_metadata_mask = [False] + char_level_metadata_mask text_with_local_metadata_encoded = tokenizer.encode_plus(text_with_local_metadata) From 23b12210bb19b52449afb29dfaa7d743869514c9 Mon Sep 17 00:00:00 2001 From: Jonathan Chang <31893406+cccntu@users.noreply.github.com> Date: Fri, 7 Apr 2023 21:29:18 +0800 Subject: [PATCH 13/13] Change exp2 to exp, increase examples limit --- bsmetadata/evaluation.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/bsmetadata/evaluation.py b/bsmetadata/evaluation.py index 2da18b1b..2692e9fc 100644 --- a/bsmetadata/evaluation.py +++ b/bsmetadata/evaluation.py @@ -506,7 +506,7 @@ def create_metadata_prompt(example: Dict[str, Any], cfg: MetadataConfig) -> str: # sys.exit() - if n_examples > 1000: + if n_examples > 2000: break if exit_flag: @@ -522,7 +522,7 @@ def ppl(examples_mean_loss, examples_len): examples_mean_loss = torch.tensor(examples_mean_loss) examples_len = torch.tensor(examples_len) weight = examples_len / examples_len.sum() - return torch.exp2((examples_mean_loss * weight).sum()).item() + return torch.exp((examples_mean_loss * weight).sum()).item() torch.save({ 'total_normal_ppl': total_normal_ppl,
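
A note on the patch 01 fix at the top of this series: on integer tensors, `~` is bitwise NOT rather than logical negation, so an int64 mask of ones and zeros inverts to -2 and -1, both of which are truthy, and `torch.logical_and` then keeps metadata positions in the loss. A minimal standalone sketch of the failure and the fix (plain Python, outside the patches; the tensors are invented):

import torch

metadata_mask = torch.tensor([1, 1, 0, 0])   # int64: 1 marks metadata tokens
attention_mask = torch.tensor([1, 1, 1, 1])

# Bitwise NOT on int64: ~1 == -2 and ~0 == -1, so every entry is nonzero/truthy.
broken = torch.logical_and(attention_mask, ~metadata_mask)
print(broken)  # tensor([True, True, True, True]): metadata leaks into the loss

# Casting to bool first makes `~` a true logical negation.
fixed = torch.logical_and(attention_mask, ~(metadata_mask.bool()))
print(fixed)   # tensor([False, False, True, True])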
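
The `first_nonmetadata` lines added in patch 06 implement the fix sketched in the patch 05 docstring diagram: without a metadata prefix the first text token is never a prediction target, but with a prefix it is (predicted from the metadata alone), so the two runs would otherwise be scored on different target sets. A small sketch of the cumsum trick, with invented masks:

import torch

metadata_mask = torch.tensor([[1, 1, 0, 0, 0]]).bool()    # two metadata tokens, then text
attention_mask = torch.tensor([[1, 1, 1, 1, 1]]).bool()

# Running count of non-metadata tokens; the count is exactly 1 at the
# first text token, which is the position to drop from the loss.
nonmetadata_cumsum = torch.cumsum(~metadata_mask, dim=-1)  # [[0, 0, 1, 2, 3]]
first_nonmetadata = nonmetadata_cumsum == 1                # [[F, F, T, F, F]]

loss_mask = attention_mask & ~metadata_mask                # [[F, F, T, T, T]]
loss_mask = loss_mask & ~first_nonmetadata                 # [[F, F, F, T, T]]
print(loss_mask)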
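
Patch 12's `--prompt` flag changes only the separators and two global processors, so the prefix reads as natural text instead of the trained `key: value | ...` format. Roughly, with invented field values (the real strings come from the repo's processors):

# Invented example values; only the separators below are taken from the patch.
fields = ["Data source: example.de > forum > article", "Number of characters: 123"]

trained_style = " | ".join(fields) + " |||"  # defaults: metadata_sep=" | ", metadata_prefix_sep=" |||"
prompt_style = "; ".join(fields)             # --prompt: metadata_sep="; ", empty prefix separator
print(trained_style)
print(prompt_style)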
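
Finally, the aggregation added in patch 12 and corrected in patch 13: `ppl_fn` now returns a per-example mean loss and a token count, and the corpus perplexity is the exponential of the token-weighted mean loss. Since `F.cross_entropy` returns nats, the base must be e, which is what the `exp2` to `exp` change in patch 13 addresses. A self-contained version of that helper with a worked example:

import torch

def ppl(examples_mean_loss, examples_len):
    examples_mean_loss = torch.tensor(examples_mean_loss)
    examples_len = torch.tensor(examples_len)
    weight = examples_len / examples_len.sum()  # weight each example by its token count
    return torch.exp((examples_mean_loss * weight).sum()).item()

# 10 tokens at mean loss 2.0 and 30 tokens at mean loss 4.0:
# exp(0.25 * 2.0 + 0.75 * 4.0) = exp(3.5), roughly 33.1
print(ppl([2.0, 4.0], [10, 30]))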