
[WIP] Experiments with BART-based Generative QA #36

@NISH1001

Description

Currently, I am experimenting with BART, an encoder-decoder model that is mostly used in seq2seq form (mainly for summarization and translation). However, I am able to train the model for generative Question Answering (generative QA).
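Concretely, "generative" here means the answer is decoded token-by-token from a context+question prompt rather than extracted as a span. A minimal sketch with vanilla facebook/bart-large (the prompt below is just an illustration):

from transformers import BartForConditionalGeneration, BartTokenizer

tokenizer = BartTokenizer.from_pretrained("facebook/bart-large")
model = BartForConditionalGeneration.from_pretrained("facebook/bart-large")

# Same "context: ...\nquestion: ..." prompt format used by the wrapper below.
prompt = "context: BART is a denoising seq2seq model.\nquestion: What is BART?"
inputs = tokenizer(prompt, return_tensors="pt", truncation=True)

gen_ids = model.generate(inputs["input_ids"], attention_mask=inputs["attention_mask"])
print(tokenizer.decode(gen_ids[0], skip_special_tokens=True))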

On a very rudimentary run on the askathon data (128 samples), I fine-tuned (and deliberately overfitted) the vanilla BART model to see if it can work nicely.
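The fine-tuning step itself isn't shown in this issue; it is a standard seq2seq fine-tune. A minimal sketch with the HF Trainer (the toy data, epoch count, and output path are assumptions for illustration, not the actual run):

from torch.utils.data import Dataset
from transformers import BartForConditionalGeneration, BartTokenizer, Trainer, TrainingArguments

tokenizer = BartTokenizer.from_pretrained("facebook/bart-large")
model = BartForConditionalGeneration.from_pretrained("facebook/bart-large")

# Hypothetical toy samples, in the same prompt format as the inference code below.
prompts = ["context: BART is a denoising seq2seq model.\nquestion: What is BART?"]
answers = ["a denoising seq2seq model"]

enc = tokenizer(prompts, truncation=True, padding="longest", return_tensors="pt")
lbl = tokenizer(answers, truncation=True, padding="longest", return_tensors="pt")["input_ids"]
lbl[lbl == tokenizer.pad_token_id] = -100  # ignore padding positions in the loss

class ToyQADataset(Dataset):
    def __len__(self):
        return len(prompts)

    def __getitem__(self, idx):
        return {
            "input_ids": enc["input_ids"][idx],
            "attention_mask": enc["attention_mask"][idx],
            "labels": lbl[idx],
        }

trainer = Trainer(
    model=model,
    args=TrainingArguments(output_dir="tmp/bart-askathon-v1", num_train_epochs=50),
    train_dataset=ToyQADataset(),
)
trainer.train()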

Then I used evalem to evaluate both models and generate a comparison table.

metric            askathon-tuned  bart-vanilla
MeteorMetric      0.456912        0.130023
BleuMetric        0.211337        0.0652976
F1Metric          0.530609        0.142459
AccuracyMetric    0.494754        0.0874828
RougeMetric       0.48671         0.101733
BertScore         0.690443        0.41951
BartScore         -3.49565        -5.38511
ExactMatchMetric  0.421875        0

evalem code

The evalem code is a monkey-patch where I have created a new temporary component for Generative QA.

I) Generative QA component (temporary for now)

import logging
from typing import Iterable

import pandas as pd
from tqdm import tqdm
from transformers import BartForConditionalGeneration, BartTokenizer, TrainingArguments, Trainer

from evalem import NamedSimpleEvaluationPipeline
from evalem.misc.utils import build_comparison_table
from evalem.nlp.evaluators import QAEvaluator
from evalem.nlp.metrics import BertScore, RougeMetric, MeteorMetric, ExactMatchMetric, BartScore, BleuMetric
from evalem.nlp.models import QuestionAnsweringHFPipelineWrapper
from evalem.nlp.models._base import HFLMWrapper
from evalem.nlp.structures import QuestionAnsweringDTO

logger = logging.getLogger(__name__)

class GenerativeBartQAWrapper(HFLMWrapper):
    """Wraps a BART seq2seq model so that evalem can run generative QA with it."""

    def _predict(self, inputs, **kwargs):
        res = []

        batch_size = kwargs.get("batch_size", 8)
        logger.debug(f"batch_size={batch_size}")
        n_items = len(inputs["input_ids"])
        n_batches = (n_items + batch_size - 1) // batch_size  # ceil division
        with tqdm(total=n_batches) as pbar:
            for input_ids, attention_mask in zip(
                self.batch_iterator(inputs["input_ids"], batch_size),
                self.batch_iterator(inputs["attention_mask"], batch_size),
            ):
                gen_ids = self.model.generate(
                    input_ids,
                    attention_mask=attention_mask,
                )
                res.extend(self.token_ids_to_token(gen_ids))
                pbar.update()
        return res

    @staticmethod
    def batch_iterator(iterable, batch_size):
        """Yields successive batch_size-sized slices of the iterable."""
        for start in range(0, len(iterable), batch_size):
            yield iterable[start:start + batch_size]

    def token_ids_to_token(self, token_ids):
        """Decodes generated token ids back into answer strings."""
        return [
            self.tokenizer.decode(token_id, skip_special_tokens=True)
            for token_id in token_ids
        ]
    
    def _preprocess_inputs(self, inputs: Iterable, **kwargs) -> Iterable:
        """
        A helper method to transform raw {context, question} dicts into
        tokenized tensors suitable for the model to ingest.
        """
        input_texts = []

        for dct in inputs:
            ctx, q = dct["context"], dct["question"]
            input_texts.append(f"context: {ctx}\nquestion: {q}")

        tokenized_inputs = self.tokenizer(
            input_texts, truncation=True, padding="longest", return_tensors="pt"
        )

        return dict(
            input_ids=tokenized_inputs["input_ids"].to(self.model.device),
            attention_mask=tokenized_inputs["attention_mask"].to(self.model.device),
        )

II) Connecting the components

DEVICE = "mps"
BATCH_SIZE = 8

tokenizer = BartTokenizer.from_pretrained("facebook/bart-large")
predictions_postprocessor = lambda x: [QuestionAnsweringDTO(value=p) for p in x]

wrapped_model_1 = GenerativeBartQAWrapper(
    model=BartForConditionalGeneration.from_pretrained("tmp/bart-askathon-v1/").to(DEVICE),
    tokenizer=tokenizer,
    predictions_postprocessor=predictions_postprocessor,
)

wrapped_model_2 = GenerativeBartQAWrapper(
    model=BartForConditionalGeneration.from_pretrained("facebook/bart-large").to(DEVICE),
    tokenizer=tokenizer,
    predictions_postprocessor=predictions_postprocessor,
)

data = pd.DataFrame(get_askathon_data("data/askathon.csv")).rename(
    columns={"contexts": "context", "questions": "question", "answers": "answer"}
)

evaluators_common = [
    QAEvaluator(),
    BertScore(device=DEVICE),
    BartScore(device=DEVICE),
    RougeMetric(),
    MeteorMetric(),
    BleuMetric(),
]

eval_pipe_1 = NamedSimpleEvaluationPipeline(
    model=wrapped_model_1,
    evaluators=evaluators_common,
    name="askathon-tuned"
)

eval_pipe_2 = NamedSimpleEvaluationPipeline(
    model=wrapped_model_2,
    evaluators=evaluators_common,
    name="bart-vanilla"
)

results = build_comparison_table(
    eval_pipe_1, eval_pipe_2,
    inputs=data[["context", "question"]].to_dict("records"),
    references=data["answer"].to_list(),
)

III) Askathon dataloader

def load_askathon_clean(path: str) -> pd.DataFrame:
    """Loads the raw askathon CSV and drops/renames columns to a clean schema."""
    data = pd.read_csv(path)
    data = data.drop(columns=["Email Address"]).reset_index(drop=True)
    data.rename(columns={
        data.columns[0] : "context",
        data.columns[1]: "id",
        data.columns[2]: "source",
        data.columns[3]: "topics",
        data.columns[4]: "q1",
        data.columns[5]: "a1",
        data.columns[6]: "q2",
        data.columns[7]: "a2",
        data.columns[8]: "q3",
        data.columns[9]: "a3",
        data.columns[10]: "q4",
        data.columns[11]: "a4",
        data.columns[12]: "q5",
        data.columns[13]: "a5"
    }, inplace=True)
    data.drop(columns=["source", "topics"], inplace=True)
    return data

def create_qa_dataset(data: pd.DataFrame) -> pd.DataFrame:
    """Explodes each row's (q1..q5, a1..a5) pairs into one record per QA pair,
    keeping only answers that can be located verbatim in the context.
    """
    res = []
    q_keys = [f"q{i}" for i in range(1, 6)]
    a_keys = [f"a{i}" for i in range(1, 6)]
    
    def _index_fn(context: str, answer: str) -> int:
        # Case-insensitive search; trailing punctuation on the answer is ignored.
        try:
            return context.lower().index(answer.rstrip(" ,.!?").lower())
        except ValueError:
            return -1
    
    for _df in data.itertuples():
        tmp = []
        for qk, ak in zip(q_keys, a_keys):
            q, a = getattr(_df, qk), getattr(_df, ak)
            
            if not isinstance(a, str):
                continue
            idx = _index_fn(_df.context, a)
            if idx > -1:
                tmp.append(dict(
                    id=str(_df.id),
                    context=_df.context,
                    question=q,
                    answer_text=a,
                    answer_start=idx,
                ))
        res.extend(tmp)
    return pd.DataFrame(res)
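For reference, the two helpers compose like this (presumably this is what get_askathon_data in section II wraps):

df_clean = load_askathon_clean("data/askathon.csv")
qa_df = create_qa_dataset(df_clean)
print(qa_df[["question", "answer_text", "answer_start"]].head())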

cc: @muthukumaranR @xhagrg
