diff --git a/bsmetadata/deepspeed_configs/v2.json b/bsmetadata/deepspeed_configs/v2.json
index 1d5c0311..dca70cc9 100644
--- a/bsmetadata/deepspeed_configs/v2.json
+++ b/bsmetadata/deepspeed_configs/v2.json
@@ -30,19 +30,19 @@
}
},
"zero_optimization": {
- "stage": 1,
- "allgather_partitions": true,
- "allgather_bucket_size": 500000000,
- "overlap_comm": true,
- "reduce_scatter": true,
- "reduce_bucket_size": 500000000,
- "contiguous_gradients": true,
- "cpu_offload": true
- },
- "gradient_accumulation_steps": 16,
+ "stage": 2,
+ "allgather_partitions": true,
+ "allgather_bucket_size": 2e8,
+ "overlap_comm": true,
+ "reduce_scatter": true,
+ "reduce_bucket_size": 2e8,
+ "contiguous_gradients": true,
+ "cpu_offload": false
+  },
+ "gradient_accumulation_steps": 2,
"gradient_clipping": "auto",
"steps_per_print": 100,
- "train_batch_size": 256,
+ "train_batch_size": 512,
"train_micro_batch_size_per_gpu": "auto",
"wall_clock_breakdown": false
}
\ No newline at end of file
diff --git a/bsmetadata/evaluation.py b/bsmetadata/evaluation.py
index b75c2aa5..b14dca92 100644
--- a/bsmetadata/evaluation.py
+++ b/bsmetadata/evaluation.py
@@ -46,6 +46,7 @@ def mean_loss_fn(
metadata_mask: torch.Tensor = None,
save_data: bool = False,
idx: int = None,
+    tokenizer=None,
) -> torch.Tensor:
"""Calculates the perplexity for a given batch.
@@ -62,15 +63,15 @@ def mean_loss_fn(
b = outputs.logits.size(0)
lm_logits = outputs.logits
- lm_logits[:, :, 50257] = float("-inf")
- lm_logits[:, :, 50258] = float("-inf")
-
labels = batch["labels"]
attention_mask = batch["attention_mask"]
shift_logits = lm_logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
if metadata_mask is not None:
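+        # Give the additional special tokens the lowest representable logit so they are excluded from the predicted distribution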
+ for special_tok in tokenizer.additional_special_tokens_ids:
+ shift_logits[:, :, special_tok] = torch.finfo(lm_logits.dtype).min
+
metadata_mask = metadata_mask.bool()
nonmetadata_cumsum = torch.cumsum(~metadata_mask, dim=-1)
first_nonmetadata = nonmetadata_cumsum == 1
@@ -133,7 +134,7 @@ def mean_loss_fn(
shift_labels.view(-1),
reduction="none",
).view(b, -1)
-
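+    # Replace any NaN entries in the per-token loss with 0 so they do not corrupt the aggregated loss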
+ loss = torch.nan_to_num(loss)
if save_data:
# Save the non-masked tokens & their loss
suffix = "_meta" if metadata_mask is not None else ""
@@ -180,6 +181,8 @@ def get_mean_loss(
batch: Dict[str, torch.Tensor],
save_data: bool = False,
idx: int = None,
+ model=None,
+    tokenizer=None,
) -> torch.Tensor:
"""Prepares the arguments for perplexity calculation and passes them to the perplexity function.
@@ -197,7 +200,7 @@ def get_mean_loss(
metadata_mask = batch.pop("metadata_mask", None)
outputs = model(**batch)
batch["labels"] = labels
- nll = mean_loss_fn(batch, outputs, metadata_mask, save_data=save_data, idx=idx)
+    nll = mean_loss_fn(batch, outputs, metadata_mask, save_data=save_data, idx=idx, tokenizer=tokenizer)
return nll
@@ -264,75 +267,28 @@ def create_metadata_prompt(example: Dict[str, Any], cfg: MetadataConfig) -> str:
return cfg.metadata_sep.join(sorted_metadata) + cfg.metadata_prefix_sep if sorted_metadata else ""
-if __name__ == "__main__":
- parser = argparse.ArgumentParser()
- parser.add_argument(
- "--repo_id",
- type=str,
- default="bs-modeling-metadata/checkpoints_all_04_23",
- help="Repository ID for the model to compute perplexity for",
- )
- parser.add_argument(
- "--subfolder",
- type=str,
- default="checkpoint-2000step",
- help="subfolder in the respository with the specific checkpoint to evaluate perplexity for",
- )
- parser.add_argument(
- "--config_file_path",
- type=str,
- help="The path actual_config.yaml if available, otherwise repo_id/actual_config.yaml or git clone's v2.yaml",
- )
- parser.add_argument(
- "--output_file", type=str, default="evaluation.txt", help="Path to the file the perplexity is written to"
- )
- parser.add_argument("--no_cuda", action="store_true", help="If set to true, all computations are performed on CPU")
- parser.add_argument(
- "--save_data",
- action="store_true",
- help="If set to true, save tokens & losses",
- )
- parser.add_argument(
- "--test",
- action="store_true",
- help="If set to true, the script runs in test mode and only takes 10 examples per dataset",
- )
- parser.add_argument(
- "--max_n_examples",
- type=int,
- default=1500,
- help="how many examples per metadata type to evaluate",
- )
- parser.add_argument(
- "--metadata_to_test",
- type=str,
- default="html,entity,entity_paragraph,website_desc,generation_datasource,timestamp,title,generation_length_sentence,generation_length_text,url,paragraph",
- help="metadata types to test",
- )
- parser.add_argument(
- "--untrained",
- action="store_true",
- help="If set to true, will load gpt2-xl",
- )
- parser.add_argument(
- "--prompt",
- action="store_true",
- help="If set to true, the script evaluates metadata in prompt style",
- )
-
- args = parser.parse_args()
- print(f"Parameters: {args}")
-
- # Load config
- if args.config_file_path:
- config_file_path = args.config_file_path
- else:
- try:
- config_file_path = hf_hub_download(
- repo_id=args.repo_id, filename="actual_config.yaml", use_auth_token=True
- )
- except Exception:
- config_file_path = "bsmetadata/hydra_configs/v2.yaml"
+def evaluate_main(
+ metadata_to_test: str = "title,html,entity_paragraph,website_desc,generation_datasource,timestamp",
+ output_file: str = "evaluation.txt",
+ repo_id: str = None,
+ subfolder: str = None,
+ test: bool = False,
+ max_n_examples: int = 1500,
+ prompt: bool = False,
+ no_cuda: bool = True,
+ save_data: bool = False,
+ untrained: bool = False,
+ config_file_path: str = None,
+    model=None,
+    tokenizer=None,
+ accelerator=None,
+) -> dict:
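+    """Compute perplexity with and without metadata for each requested metadata type; returns a dict keyed by metadata type."""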
+    if config_file_path is None:
+        try:
+            config_file_path = hf_hub_download(repo_id=repo_id, filename="actual_config.yaml", use_auth_token=True)
+        except Exception:
+            config_file_path = "bsmetadata/hydra_configs/v2.yaml"
repo_args = OmegaConf.load(config_file_path)
data_config = repo_args.data_config
@@ -341,15 +297,17 @@ def create_metadata_prompt(example: Dict[str, Any], cfg: MetadataConfig) -> str:
# Load model
print("Loading model...")
- if args.untrained:
- model = AutoModelForCausalLM.from_pretrained("gpt2-xl")
- else:
- model = AutoModelForCausalLM.from_pretrained(args.repo_id, subfolder=args.subfolder, use_auth_token=True)
- model.eval().cuda() if not args.no_cuda else model.eval()
-
- # Load tokenizer
- tokenizer = AutoTokenizer.from_pretrained(repo_args.model_name)
- tokenizer.pad_token = tokenizer.eos_token
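+    # Reuse a model/tokenizer passed in by the caller (e.g. the training loop); otherwise load them from the Hub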
+ if model is None or tokenizer is None:
+ if untrained:
+ model = AutoModelForCausalLM.from_pretrained("gpt2-xl")
+ tokenizer = AutoTokenizer.from_pretrained(repo_args.model_name)
+ tokenizer.pad_token = tokenizer.eos_token
+ else:
+ model = AutoModelForCausalLM.from_pretrained(repo_id, subfolder=subfolder, use_auth_token=True)
+ tokenizer = AutoTokenizer.from_pretrained(
+ "bs-modeling-metadata/checkpoints_all_04_23", subfolder="tokenizer", use_auth_token=True
+ )
+ model.eval().cuda() if not no_cuda else model.eval()
# Config preprocess function
cfg = data_config.metadata_config
@@ -358,7 +316,7 @@ def create_metadata_prompt(example: Dict[str, Any], cfg: MetadataConfig) -> str:
cfg.metadata_list.append("entity")
cfg.metadata_list.append("paragraph")
- if args.prompt:
+ if prompt:
cfg.metadata_sep = "; " # Instead of " | "
cfg.metadata_prefix_sep = "" # Instead of " |||"; there's already an implicit " "
DatasourceProcessor.process_global = datasource_process_global_for_prompt
@@ -381,8 +339,8 @@ def create_metadata_prompt(example: Dict[str, Any], cfg: MetadataConfig) -> str:
"bs-modeling-metadata/c4-en-html-with-validation_metadata_url",
"bs-modeling-metadata/c4-en-html-with-validation_metadata_paragraph",
]
- dataset_paths = [path for path in dataset_paths if path.split("_metadata_")[1] in args.metadata_to_test.split(",")]
-
+ dataset_paths = [path for path in dataset_paths if path.split("_metadata_")[1] in metadata_to_test.split(",")]
+ results = {}
for path in dataset_paths:
n_examples = 0
total_normal_len = []
@@ -394,11 +352,11 @@ def create_metadata_prompt(example: Dict[str, Any], cfg: MetadataConfig) -> str:
# Load validation dataset from hugging face
metadata_type = path.split("_metadata_")[1]
print(f"Loading {metadata_type} data...")
- split = "validation" if not args.test else "validation[:10]"
+ split = "validation" if not test else "validation[:10]"
validation_dataset = load_dataset(path, use_auth_token=True, split=split)
data = []
- max_n_examples_ord = len(str(args.max_n_examples))
+ max_n_examples_ord = len(str(max_n_examples))
for idx, example in tqdm(enumerate(validation_dataset), desc=f"Calculating perplexity for {metadata_type}..."):
# for idx in [136,]:
example = validation_dataset[idx]
@@ -409,7 +367,7 @@ def create_metadata_prompt(example: Dict[str, Any], cfg: MetadataConfig) -> str:
except Exception as e:
# Write error to output file and continue with next dataset
print(e)
- with open(args.output_file, "a", encoding="utf8") as f:
+ with open(output_file, "a", encoding="utf8") as f:
f.write(f"=== RESULT [{metadata_type}] ===\n")
f.write(f"{e}\n\n")
exit_flag = True
@@ -445,7 +403,10 @@ def create_metadata_prompt(example: Dict[str, Any], cfg: MetadataConfig) -> str:
normal_batch = default_data_collator([normal_example])
metadata_example["labels"] = metadata_example["input_ids"]
metadata_batch = default_data_collator([metadata_example])
- if not args.no_cuda:
+ if accelerator is not None:
+ normal_batch = {k: v.to(accelerator.device) for k, v in normal_batch.items()}
+ metadata_batch = {k: v.to(accelerator.device) for k, v in metadata_batch.items()}
+ elif not no_cuda:
normal_batch = {k: v.cuda() for k, v in normal_batch.items()}
metadata_batch = {k: v.cuda() for k, v in metadata_batch.items()}
if n_examples == 1:
@@ -461,12 +422,14 @@ def create_metadata_prompt(example: Dict[str, Any], cfg: MetadataConfig) -> str:
# rich.print(tokenizer.decode(metadata_batch["input_ids"][0]))
# Calculate nll (natural-log loss)
- normal_nll, normal_example_len = get_mean_loss(normal_batch, save_data=args.save_data, idx=idx) # [0]
+ normal_nll, normal_example_len = get_mean_loss(
+                normal_batch, save_data=save_data, idx=idx, model=model, tokenizer=tokenizer
+ ) # [0]
# print("PPL")
# print(normal_ppl)
total_normal_nll.append(normal_nll) # * normal_example_len
metadata_nll, metadata_example_len = get_mean_loss(
- metadata_batch, save_data=args.save_data, idx=idx
+                metadata_batch, save_data=save_data, idx=idx, model=model, tokenizer=tokenizer
) # [0]
# print(metadata_ppl)
total_metadata_nll.append(metadata_nll) # * metadata_example_len
@@ -521,7 +484,7 @@ def create_metadata_prompt(example: Dict[str, Any], cfg: MetadataConfig) -> str:
# sys.exit()
- if n_examples > args.max_n_examples:
+ if n_examples > max_n_examples:
break
if exit_flag:
@@ -554,9 +517,86 @@ def ppl(examples_mean_loss, examples_len):
else:
final_metadata_ppl = final_normal_ppl = 0
- # Write results to output file
- with open(args.output_file, "a", encoding="utf8") as f:
- f.write(f"=== RESULT [{metadata_type}] ===\n")
- f.write("Perplexity (metadata): {:>6,.3f}\n".format(final_metadata_ppl))
- f.write("Perplexity (normal): {:>6,.3f}\n\n".format(final_normal_ppl))
+ results[metadata_type] = {"final_normal_ppl": final_normal_ppl, "final_metadata_ppl": final_metadata_ppl}
torch.save(data, "eva.data")
+ return results
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ "--repo_id",
+ type=str,
+ default="bs-modeling-metadata/checkpoints_all_04_23",
+ help="Repository ID for the model to compute perplexity for",
+ )
+ parser.add_argument(
+ "--subfolder",
+ type=str,
+ default="checkpoint-2000step",
+ help="subfolder in the respository with the specific checkpoint to evaluate perplexity for",
+ )
+ parser.add_argument(
+ "--config_file_path",
+ type=str,
+ help="The path actual_config.yaml if available, otherwise repo_id/actual_config.yaml or git clone's v2.yaml",
+ )
+ parser.add_argument(
+ "--output_file", type=str, default="evaluation.txt", help="Path to the file the perplexity is written to"
+ )
+ parser.add_argument("--no_cuda", action="store_true", help="If set to true, all computations are performed on CPU")
+ parser.add_argument(
+ "--save_data",
+ action="store_true",
+ help="If set to true, save tokens & losses",
+ )
+ parser.add_argument(
+ "--test",
+ action="store_true",
+ help="If set to true, the script runs in test mode and only takes 10 examples per dataset",
+ )
+ parser.add_argument(
+ "--max_n_examples",
+ type=int,
+ default=1500,
+ help="how many examples per metadata type to evaluate",
+ )
+ parser.add_argument(
+ "--metadata_to_test",
+ type=str,
+ default="html,entity,entity_paragraph,website_desc,generation_datasource,timestamp,title,generation_length_sentence,generation_length_text,url,paragraph",
+ help="metadata types to test",
+ )
+ parser.add_argument(
+ "--untrained",
+ action="store_true",
+ help="If set to true, will load gpt2-xl",
+ )
+ parser.add_argument(
+ "--prompt",
+ action="store_true",
+ help="If set to true, the script evaluates metadata in prompt style",
+ )
+
+ args = parser.parse_args()
+ print(f"Parameters: {args}")
+ results = evaluate_main(
+ repo_id=args.repo_id,
+ subfolder=args.subfolder,
+ config_file_path=args.config_file_path,
+ output_file=args.output_file,
+ save_data=args.save_data,
+ test=args.test,
+ max_n_examples=args.max_n_examples,
+ metadata_to_test=args.metadata_to_test,
+ untrained=args.untrained,
+ prompt=args.prompt,
+ no_cuda=args.no_cuda,
+ )
+ # Write results to output file
+ with open(args.output_file, "a", encoding="utf8") as f:
+ for k, v in results.items():
+ f.write(f"=== RESULT [{k}] ===\n")
+ f.write("Perplexity (metadata): {:>6,.3f}\n".format(v["final_metadata_ppl"]))
+ f.write("Perplexity (normal): {:>6,.3f}\n\n".format(v["final_normal_ppl"]))
diff --git a/bsmetadata/experiments/with_metadata_datasetv2_tf.py b/bsmetadata/experiments/with_metadata_datasetv2_tf.py
index e4ebdd28..1f45ad4a 100644
--- a/bsmetadata/experiments/with_metadata_datasetv2_tf.py
+++ b/bsmetadata/experiments/with_metadata_datasetv2_tf.py
@@ -87,7 +87,7 @@ def filter_empty(t):
return data
-def get_dataloader(*, tokenizer, args, num_gpus, gpu_id):
+def get_dataloader(*, tokenizer, args, num_gpus, gpu_id, train=True):
"""returns a tensorflow dataloader"""
data_config = args
local_dir = Path(data_config.dataset_name)
@@ -104,14 +104,22 @@ def get_dataloader(*, tokenizer, args, num_gpus, gpu_id):
print(f"{len(files_with_entities)} files with entities")
print(f"{len(files_without_entities)} files without entities")
+ if train:
+ files_with_entities = [
+ x for x in files_with_entities if "c4-en-html_cc-main-2019-18_pq00-000.jsonl.gz" not in x.name
+ ]
+ else:
+ files_with_entities = [
+ x for x in files_with_entities if "c4-en-html_cc-main-2019-18_pq00-000.jsonl.gz" in x.name
+ ]
+
data_with_entities = get_dataset(files_with_entities, num_gpus, gpu_id, data_config, tokenizer)
- data_without_entities = get_dataset(files_without_entities, num_gpus, gpu_id, data_config, tokenizer)
+
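+    # Only the files with entities remain in the mix; sampling from a single dataset simply yields its examples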
data = tf.data.Dataset.sample_from_datasets(
- [data_with_entities, data_without_entities],
- weights=[float(len(files_with_entities)), float(len(files_without_entities))],
+ [data_with_entities],
+ weights=[float(len(files_with_entities))],
seed=42,
)
-
data = data.shuffle(1000, reshuffle_each_iteration=True)
data = data.batch(data_config.per_device_train_batch_size)
data = data.prefetch(tf.data.AUTOTUNE)
diff --git a/bsmetadata/hydra_configs/v2.yaml b/bsmetadata/hydra_configs/v2.yaml
index 42a0044a..805d593f 100644
--- a/bsmetadata/hydra_configs/v2.yaml
+++ b/bsmetadata/hydra_configs/v2.yaml
@@ -11,7 +11,7 @@ data_config:
title: 1.0657717366883845
generation_datasource: 1.0
entity_paragraph: 1.028817740667444
-
+ generation_length_text: 1.0
#- url: 1.0
#- generation_length_sentence
#- generation_length_text
@@ -28,6 +28,7 @@ data_config:
- datasource
- length
- entity_paragraph
+ - generation_length_text
metadata_column_list:
- html
- timestamp
@@ -36,9 +37,9 @@ data_config:
#- url
- generation_datasource
#- generation_length_sentence
- #- generation_length_text
+ - generation_length_text
- entity_paragraph
- local_metadata_special_tokens:
+ local_metadata_special_tokens:
entity_paragraph: "entity"
metadata_sep: ' | '
metadata_key_value_sep: ': '
@@ -72,14 +73,16 @@ data_config:
- 0.0
local_metadata_special_token_start:
entity_paragraph: ""
+ html: ""
local_metadata_special_token_end:
entity_paragraph: " "
+ html: ""
local_metadata_special_token_state: true
- html_overall_sample_rate: 0.25
+ html_overall_sample_rate: 0.5
without_metadata_same_context: false
experiment: with_metadata_datasetv2_tf
- per_device_eval_batch_size: 8
- per_device_train_batch_size: 8
+ per_device_eval_batch_size: 32 # 32 for 40gb
+ per_device_train_batch_size: 32
dataset_name: bs-modeling-metadata/c4-en-html-with-training_metadata_all
dataset_config_name: null
train_file: c4-en-html_cc-main-2019-18_pq00-000.jsonl.gz
@@ -87,12 +90,13 @@ data_config:
overwrite_cache: false
cache_dir: null
extension: null
- preprocessing_num_workers: 6
+ preprocessing_num_workers: 40
validation_split_percentage: 5
block_size: null
map_batch_size: 1
weight_decay: 0.01
-learning_rate: 5e-5
+learning_rate: 0.0001
+wb_name: "all_metadata"
num_train_epochs: 1
max_train_steps: 100000
lr_scheduler_type: linear
@@ -103,16 +107,16 @@ model_name: gpt2
project_name: metadata_lm
jobid: ''
start_with_eval: false
-extra_steps_to_eval_save_at:
-- 2
+#extra_steps_to_eval_save_at:
+#- 2
evaluation_strategy: STEPS
eval_num_per_epoch: 3
-eval_steps: 2000
+eval_steps: 1000
save_strategy: STEPS
save_num_per_epoch: 3
-save_steps: 150
+save_steps: 500
do_train: true
do_eval: true
gradient_checkpointing: true
resume_from_checkpoint_dir: null
-gradient_accumulation_steps: 16
+gradient_accumulation_steps: 2
\ No newline at end of file
diff --git a/bsmetadata/train.py b/bsmetadata/train.py
index d97853b6..4dc55a2c 100644
--- a/bsmetadata/train.py
+++ b/bsmetadata/train.py
@@ -16,11 +16,11 @@
import wandb
from accelerate import Accelerator
from accelerate.utils import DistributedType, DummyOptim, DummyScheduler
+from bsmetadata.evaluation import evaluate_main
from hydra.core.config_store import ConfigStore
from omegaconf import OmegaConf
-from torch.optim import AdamW
from tqdm.auto import tqdm as original_tqdm
-from transformers import AddedToken, AutoConfig, AutoModelForCausalLM, AutoTokenizer, get_scheduler, set_seed
+from transformers import AdamW, AddedToken, AutoConfig, AutoModelForCausalLM, AutoTokenizer, get_scheduler, set_seed
from transformers.trainer_utils import IntervalStrategy
from bsmetadata.input_pipeline import DataConfig, get_dataloaders
@@ -34,6 +34,7 @@ class CFG:
data_config: DataConfig = DataConfig()
weight_decay: float = field(default=0.0, metadata={"help": "The weight decay to use for training."})
learning_rate: float = field(default=5e-5, metadata={"help": "The initial learning rate."})
+ wb_name: str = field(default="bsmetadata", metadata={"help": "The name of the wandb project."})
gradient_accumulation_steps: int = field(
default=1,
metadata={"help": "The number of gradient accumulation steps to perform before updating model parameters."},
@@ -217,8 +218,8 @@ def main(args: CFG) -> None:
is_local_main_process = accelerator.is_local_main_process
tqdm = partial(original_tqdm, disable=not is_local_main_process, position=0)
use_deepspeed = accelerator.state.deepspeed_plugin is not None
- use_deepspeed_optimzer = use_deepspeed and "optimizer" in accelerator.state.deepspeed_plugin.deepspeed_config
- use_deepspeed_scheduler = use_deepspeed and "scheduler" in accelerator.state.deepspeed_plugin.deepspeed_config
+ use_deepspeed_optimzer = use_deepspeed or "optimizer" in accelerator.state.deepspeed_plugin.deepspeed_config
+ use_deepspeed_scheduler = use_deepspeed or "scheduler" in accelerator.state.deepspeed_plugin.deepspeed_config
if accelerator.distributed_type == DistributedType.DEEPSPEED and not use_deepspeed_scheduler:
assert False, "Please set scheduler in DeepSpeed config file otherwise it may not be checkpointed properly"
@@ -294,7 +295,13 @@ def main(args: CFG) -> None:
gpu_id=accelerator.process_index,
)
dummy_dataloader = get_dummy_dataloader(args.data_config.per_device_train_batch_size)
- eval_dataloaders = dict()
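+    # Build a validation dataloader over the held-out shard (train=False selects only that file)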
+ eval_dataloader, format_fn_eval = get_dataloader(
+ tokenizer=tokenizer,
+ args=args.data_config,
+ num_gpus=accelerator.num_processes,
+ gpu_id=accelerator.process_index,
+ train=False,
+ )
model, optimizer, dummy_dataloader, scheduler = accelerator.prepare(
model, optimizer, dummy_dataloader, scheduler
)
@@ -348,7 +355,7 @@ def format_fn(x):
save_per_n_step = args.max_train_steps + 1 # will never eval
@torch.no_grad()
- def evaluate(eval_dataloader):
+ def evaluate(eval_dataloader, only_first_n_steps=120):
model.eval()
losses = []
for step, batch in enumerate(tqdm(eval_dataloader, desc="eval")): # , leave=False)
@@ -359,7 +366,8 @@ def evaluate(eval_dataloader):
loss = loss_fn(batch, outputs, metadata_mask)
losses.append(accelerator.gather(loss.repeat(args.data_config.per_device_eval_batch_size)))
-
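+            # Evaluate on at most only_first_n_steps batches before returning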
+ if step == only_first_n_steps:
+ break
model.train()
if not losses:
# in case the dataloader is empty
@@ -368,23 +376,39 @@ def evaluate(eval_dataloader):
perplexity = math.exp(torch.mean(losses))
return {"perplexity": perplexity}
- def evaluate_multiple_dateloaders(eval_dataloaders):
- for key, eval_dataloader in eval_dataloaders.items():
- logger.info(f"Evaluating split {key}")
- metrics = evaluate(eval_dataloader)
- metrics_logger.log({key: metrics})
+ def evaluate_multiple_dateloaders(eval_dataloaders, use_full_evaluation_for_val):
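+        # With use_full_evaluation_for_val, run the standalone metadata perplexity evaluation; otherwise loop over the streaming eval dataloaders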
+ if use_full_evaluation_for_val:
+ results = evaluate_main(
+ output_file="eval.txt",
+ # metadata_to_test="entity_paragraph",
+ metadata_to_test="title,html,entity_paragraph,website_desc,generation_datasource,timestamp,generation_length_text",
+ model=model,
+ tokenizer=tokenizer,
+ accelerator=accelerator,
+ )
+ model.train()
+ for k, v in results.items():
+ metrics_logger.log({k: v})
+ else:
+ for key, eval_dataloader in eval_dataloaders.items():
+ logger.info(f"Evaluating split {key}")
+ metrics = evaluate(eval_dataloader)
+ metrics_logger.log({key: metrics})
logger.info("Evaluation finished")
if not args.do_train and not args.do_eval:
return
progress_bar = tqdm(range(args.max_train_steps), desc="training", initial=train_state.completed_steps)
- metrics_logger = Logger(is_local_main_process, project=args.project_name, config=config_dict)
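+    # Effective global train batch size, used only in the wandb run name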
+    t_bs = args.data_config.per_device_train_batch_size * args.gradient_accumulation_steps * accelerator.num_processes
+ os.environ['WANDB_API_KEY'] = 'd8216641d549f9bb3d0c5074baa39e15dfd55030'
+ metrics_logger = Logger(is_local_main_process, name=f"{args.wb_name}-{args.learning_rate}-{t_bs}",
+ entity='jordanclive', project='metadata', config=config_dict)
do_eval = args.do_eval and args.start_with_eval
if do_eval:
logger.info("Start with an evaluation")
- evaluate_multiple_dateloaders(eval_dataloaders)
+ evaluate_multiple_dateloaders(eval_dataloaders, use_full_evaluation_for_val=True)
if not args.do_train:
return
@@ -406,7 +430,7 @@ def save(path):
model.save_checkpoint(path)
else:
accelerator.save_state(path)
- save_model_and_tokenizer(accelerator, model, path)
+ save_model_and_tokenizer(accelerator, model, path, tokenizer=tokenizer)
if is_local_main_process:
train_state.save(path / "train_state.json")
@@ -426,6 +450,17 @@ def get_data_iter():
batch = {k: v.to(accelerator.device) for k, v in batch.items()}
yield batch
+ def get_eval_data_iter():
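+        # Cycle the validation dataloader indefinitely so evaluation can always draw a fresh batch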
+ while True:
+ for batch in eval_dataloader:
+ batch = format_fn_eval(batch)
+ if args.data_config.experiment == "with_metadata_datasetv2_tf":
+ batch = {k: v.to(accelerator.device) for k, v in batch.items()}
+ yield batch
+
+ eval_iter = get_eval_data_iter()
+ eval_dataloaders = {"validation": eval_iter}
+
data_iter = get_data_iter()
for _ in tqdm(
@@ -461,11 +496,18 @@ def get_data_iter():
optimizer.zero_grad()
step_loss_gathered = accelerator.gather(step_loss).mean().item()
- metrics = {
- "loss": step_loss_gathered,
- "lr": max(scheduler.get_lr()),
- "gradient_step": train_state.completed_steps,
- }
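+            # For the first few steps, log a placeholder learning rate of 0 instead of querying the scheduler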
+ if step < 20:
+ metrics = {
+ "loss": step_loss_gathered,
+ "lr": 0,
+ "gradient_step": train_state.completed_steps,
+ }
+ else:
+ metrics = {
+ "loss": step_loss_gathered,
+ "lr": max(scheduler.get_last_lr()),
+ "gradient_step": train_state.completed_steps,
+ }
if not args.data_config.streaming:
metrics["epoch"] = step / len(train_dataloader)
@@ -488,7 +530,7 @@ def get_data_iter():
path = Path(args.out_dir).resolve() / f"checkpoint-{completed_steps}step"
save(path)
if do_eval:
- evaluate_multiple_dateloaders(eval_dataloaders)
+ evaluate_multiple_dateloaders(eval_dataloaders, use_full_evaluation_for_val=True)
if completed_steps >= args.max_train_steps:
# finished = True
diff --git a/script.sh b/script.sh
new file mode 100644
index 00000000..ed7e4f83
--- /dev/null
+++ b/script.sh
@@ -0,0 +1,39 @@
+source /fsx/home-jordiclive/miniconda3/bin/activate meta_conda
+cd /fsx/home-jordiclive/metadata
+export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:$CONDA_PREFIX/lib/
+export TRANSFORMERS_CACHE=/fsx/home-jordiclive/transformers_cache
+
+#export HF_DATASETS_OFFLINE=1
+#export TRANSFORMERS_OFFLINE=1
+#export WANDB_MODE=offline
+export HYDRA_FULL_ERROR=1
+
+
+export MODEL=gpt2-xl
+export NUM_GPU=8
+export DEEPSPEED_CONFIG=$(realpath bsmetadata/deepspeed_configs/v2.json)
+export DATA_DIR=$(realpath /fsx/home-jordiclive/metadata/local-data/datasets--bs-modeling-metadata--c4-en-html-with-training_metadata_all/snapshots/8f2615d8b8580e89533b90bc3931e0b99ef15aec)
+echo "deepspeed_config_file: $DEEPSPEED_CONFIG"
+
+export WANDB_API_KEY='d8216641d549f9bb3d0c5074baa39e15dfd55030'
+
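+# Write a single-node DeepSpeed accelerate config that points at the DeepSpeed JSON above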
+echo "compute_environment: LOCAL_MACHINE
+deepspeed_config:
+ deepspeed_config_file: $DEEPSPEED_CONFIG
+distributed_type: DEEPSPEED
+fp16: true
+machine_rank: 0
+main_process_ip: null
+main_process_port: null
+main_training_function: main
+num_machines: 1
+num_processes: -1
+mixed_precision: fp16
+" > accelerate_config.yaml
+CUDA_LAUNCH_BLOCKING=1 CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 accelerate launch --config_file accelerate_config.yaml bsmetadata/train.py --config-name v2 \
+ model_name=$MODEL \
+ data_config.dataset_name=$DATA_DIR \
+ data_config.train_file='*.jsonl.gz' \
+ data_config.validation_file='c4-en-html_cc-main-2019-18_pq00-000.jsonl.gz' \
+ out_dir=/fsx/home-jordiclive/tmp/metadata-html-half \
+# wb_name="full-metadata-with-generation-text-0.5-html"
\ No newline at end of file
diff --git a/slurm_40.sh b/slurm_40.sh
new file mode 100644
index 00000000..7f483bb7
--- /dev/null
+++ b/slurm_40.sh
@@ -0,0 +1,47 @@
+#!/bin/bash
+#SBATCH --account laion
+#SBATCH --partition="g40"
+#SBATCH --job-name=flan
+#SBATCH --gres=gpu:8
+#SBATCH --ntasks-per-node=8
+#SBATCH --cpus-per-task=12
+#SBATCH --output=%x_%j.out
+source /fsx/home-jordiclive/miniconda3/bin/activate meta_conda
+cd /fsx/home-jordiclive/metadata
+export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:$CONDA_PREFIX/lib/
+export TRANSFORMERS_CACHE=/fsx/home-jordiclive/transformers_cache
+
+#export HF_DATASETS_OFFLINE=1
+#export TRANSFORMERS_OFFLINE=1
+#export WANDB_MODE=offline
+export HYDRA_FULL_ERROR=1
+
+
+export MODEL=gpt2-xl
+export NUM_GPU=8
+export DEEPSPEED_CONFIG=$(realpath bsmetadata/deepspeed_configs/v2.json)
+export DATA_DIR=$(realpath /fsx/home-jordiclive/metadata/local-data/datasets--bs-modeling-metadata--c4-en-html-with-training_metadata_all/snapshots/8f2615d8b8580e89533b90bc3931e0b99ef15aec)
+echo "deepspeed_config_file: $DEEPSPEED_CONFIG"
+
+export WANDB_API_KEY='d8216641d549f9bb3d0c5074baa39e15dfd55030'
+
+echo "compute_environment: LOCAL_MACHINE
+deepspeed_config:
+ deepspeed_config_file: $DEEPSPEED_CONFIG
+distributed_type: DEEPSPEED
+fp16: true
+machine_rank: 0
+main_process_ip: null
+main_process_port: null
+main_training_function: main
+num_machines: 1
+num_processes: -1
+mixed_precision: fp16
+" > accelerate_config.yaml
+CUDA_LAUNCH_BLOCKING=1 CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 accelerate launch --config_file accelerate_config.yaml bsmetadata/train.py --config-name v2 \
+ model_name=$MODEL \
+ data_config.dataset_name=$DATA_DIR \
+ data_config.train_file='*.jsonl.gz' \
+ data_config.validation_file='c4-en-html_cc-main-2019-18_pq00-000.jsonl.gz' \
+ out_dir=/fsx/home-jordiclive/tmp/metadata-html-half \
+# wb_name="full-metadata-with-generation-text-0.5-html"
\ No newline at end of file
diff --git a/tests/test_metadata_utils.py b/tests/test_metadata_utils.py
index 49408078..51ecdcb8 100644
--- a/tests/test_metadata_utils.py
+++ b/tests/test_metadata_utils.py
@@ -454,6 +454,24 @@ def test_entity_settings(self):
"EntityOn |EntityParagraphOn ||| |United Kingdom| |Louis Vuitton| |Billy Connolly| |Something in Common| |Lembit Öpik| Hints and tips for media appearances, speaking and social media. This week; wall-to-wall politicians; Great Britain [[United Kingdom]]: Louis Vuitton [[Louis Vuitton]] condoms; Billy Connolly [[Billy Connolly]],; Lisa Dutton; Something in Common [[Something in Common]]; What was I saying?: We’re all publishers; An interview with Lembit Opik [[Lembit Öpik]]; Music from The Good Suns",
)
+ def test_html_special_token_settings(self):
+ cfg = MetadataConfig()
+ PROCESSORS["html"] = HtmlProcessor
+ cfg.metadata_list = ["html"]
+ cfg.treat_local_metadata_as_regular_text = True
+ cfg.local_metadata_special_token_start = {"html": ""}
+ cfg.local_metadata_special_token_end = {"html": ""}
+ text, mask = add_local_metadata_to_text(self.examples[1], cfg)
+ self.assertEqual(
+ text,
+ "An apple is an edible fruit "
+ "produced by an "
+ "apple tree"
+ " (Malus domestica).",
+ )
+
def test_add_local_metadata_to_text(self):
cfg = MetadataConfig()
cfg.metadata_list = ["html", "entity"]