diff --git a/bsmetadata/hydra_configs/v2.yaml b/bsmetadata/hydra_configs/v2.yaml index 887fe5df..72cff715 100644 --- a/bsmetadata/hydra_configs/v2.yaml +++ b/bsmetadata/hydra_configs/v2.yaml @@ -40,15 +40,22 @@ data_config: #- generation_length_sentence - generation_length_text - entity_paragraph - local_metadata_special_tokens: - entity_paragraph: "entity" - metadata_sep: ' | ' - metadata_key_value_sep: ': ' + local_metadata_special_tokens: + entity_paragraph: "" + html: "" + prefix_sep_tokens: + title: "" + website_description: "" + datasource: "" + text_length: "" + url: "" + metadata_prefix_sep: '' + metadata_sep: '' + metadata_key_value_sep: '' + metadata_prefix_start_seq: '' metadata_probability: 0.5 treat_local_metadata_as_regular_text: true add_local_metadata_special_tokens_in_prefix: true - metadata_prefix_sep: ' |||' - metadata_prefix_start_seq: '' max_seq_len: 1024 html_parser_config: all_tags_rules: @@ -76,7 +83,7 @@ data_config: entity_paragraph: "" html: "" local_metadata_special_token_end: - entity_paragraph: " " + entity_paragraph: "" html: "" local_metadata_special_token_state: true html_overall_sample_rate: 0.25 diff --git a/bsmetadata/metadata_processors.py b/bsmetadata/metadata_processors.py index fe88ce36..312b1489 100644 --- a/bsmetadata/metadata_processors.py +++ b/bsmetadata/metadata_processors.py @@ -141,7 +141,7 @@ class MetadataConfig: }, ) metadata_prefix_sep: str = field( - default=" |||", + default="", metadata={ "help": "The character sequence that is used to separate all global metadata and/or local metadata " "special tokens (if `add_local_metadata_special_tokens_in_prefix` is `True`) from the actual text." @@ -351,7 +351,9 @@ class UrlProcessor(MetadataProcessor): def process_global(self, metadata_attrs: Dict[str, Any]) -> Optional[str]: # We represent a URL with unquoted format such that less confusion for a tokenizer. # Example: "foo.bar/Year 2021/" instead of "foo.bar/Year%202021/". - return "".join([metadata_attrs["key"], self.cfg.metadata_key_value_sep, unquote_plus(metadata_attrs["value"])]) + return "".join( + [self.cfg.prefix_sep_tokens["url"], self.cfg.metadata_key_value_sep, unquote_plus(metadata_attrs["value"])] + ) class TitleProcessor(MetadataProcessor): @@ -360,7 +362,7 @@ class TitleProcessor(MetadataProcessor): def process_global(self, metadata_attrs: Dict[str, Any]) -> Optional[str]: # We represent a title by the title of the corresponding webpage content. # Example: "My Thoughts On It ยป Dad, I want to be an inventor". - return "".join(["Title", self.cfg.metadata_key_value_sep, metadata_attrs["value"]]) + return "".join([self.cfg.prefix_sep_tokens["title"], self.cfg.metadata_key_value_sep, metadata_attrs["value"]]) class WebsiteDescriptionProcessor(MetadataProcessor): @@ -368,7 +370,13 @@ class WebsiteDescriptionProcessor(MetadataProcessor): def process_global(self, metadata_attrs: Dict[str, Any]) -> Optional[str]: # Example: "website_description: BBC is a news organization". - return "".join(["Website Description", self.cfg.metadata_key_value_sep, metadata_attrs["value"]]) + return "".join( + [ + self.cfg.prefix_sep_tokens["website_description"], + self.cfg.metadata_key_value_sep, + metadata_attrs["value"], + ] + ) class DatasourceProcessor(MetadataProcessor): @@ -378,7 +386,9 @@ def process_global(self, metadata_attrs: Dict[str, Any]) -> Optional[str]: # We represent the DATASOURCE by using meaningful information of the URL. # URL: http://www.example.de/2015/forum/article/21-new-project # Example: example.de > forum > article > new project - return "".join(["Datasource", self.cfg.metadata_key_value_sep, metadata_attrs["value"]]) + return "".join( + [self.cfg.prefix_sep_tokens["datasource"], self.cfg.metadata_key_value_sep, metadata_attrs["value"]] + ) class GenerationLengthProcessor(MetadataProcessor): @@ -388,7 +398,9 @@ def process_global(self, metadata_attrs: Dict[str, Any]) -> Optional[str]: # We represent the length of a text by the number of characters. # Example: Length: 123 - return "".join(["Text Length", self.cfg.metadata_key_value_sep, metadata_attrs["value"]]) + return "".join( + [self.cfg.prefix_sep_tokens["text_length"], self.cfg.metadata_key_value_sep, metadata_attrs["value"]] + ) class BasicStartLocalProcessor(MetadataProcessor): diff --git a/bsmetadata/train.py b/bsmetadata/train.py index d97853b6..57a17d26 100644 --- a/bsmetadata/train.py +++ b/bsmetadata/train.py @@ -241,6 +241,9 @@ def main(args: CFG) -> None: ) ) ) + new_tokens.append(args.data_config.metadata_config.metadata_prefix_sep) + new_tokens.extend(args.data_config.metadata_config.prefix_sep_tokens.values()) + new_tokens.extend(args.data_config.metadata_config.local_metadata_special_tokens.values()) new_tokens = [ AddedToken(token, rstrip=False, lstrip=False, single_word=False, normalized=False) for token in new_tokens ]