Currently, I am experimenting with BART, an encoder-decoder model that is mostly used in a seq2seq setup (mainly for summarization and translation). However, I am able to train the model for Question Answering (generative QA) as well.
On a very rudimentary run on the askathon data (128 samples), I fine-tuned (and deliberately overfitted) the vanilla BART model to see if it can work nicely; a rough sketch of the fine-tuning setup is included after the table.
Then I used evalem to evaluate both the tuned and the vanilla model and generate the comparison table below.
| metric | askathon-tuned | bart-vanilla |
|---|---|---|
| MeteorMetric | 0.456912 | 0.130023 |
| BleuMetric | 0.211337 | 0.0652976 |
| F1Metric | 0.530609 | 0.142459 |
| AccuracyMetric | 0.494754 | 0.0874828 |
| RougeMetric | 0.48671 | 0.101733 |
| BertScore | 0.690443 | 0.41951 |
| BartScore | -3.49565 | -5.38511 |
| ExactMatchMetric | 0.421875 | 0 |
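The fine-tuning script itself is not part of this issue, so for reference here is a minimal sketch of how the seq2seq fine-tuning could look. This is only an illustration, not the exact script I ran: the `QADataset` helper, `train_records`, the hyperparameters, and the `tmp/bart-askathon-v1/` output path are placeholders.

```python
# Illustrative fine-tuning sketch (not the exact training script used here).
# It mirrors the "context: ...\nquestion: ..." prompt format used by
# GenerativeBartQAWrapper._preprocess_inputs below.
from torch.utils.data import Dataset
from transformers import BartForConditionalGeneration, BartTokenizer, Trainer, TrainingArguments


class QADataset(Dataset):
    def __init__(self, records, tokenizer, max_length=512):
        self.records = records          # list of dicts with context/question/answer_text
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.records)

    def __getitem__(self, idx):
        rec = self.records[idx]
        source = f"context: {rec['context']}\nquestion: {rec['question']}"
        enc = self.tokenizer(source, truncation=True, max_length=self.max_length,
                             padding="max_length", return_tensors="pt")
        labels = self.tokenizer(rec["answer_text"], truncation=True, max_length=64,
                                padding="max_length", return_tensors="pt")["input_ids"]
        labels[labels == self.tokenizer.pad_token_id] = -100  # ignore padding in the loss
        return {
            "input_ids": enc["input_ids"].squeeze(0),
            "attention_mask": enc["attention_mask"].squeeze(0),
            "labels": labels.squeeze(0),
        }


tokenizer = BartTokenizer.from_pretrained("facebook/bart-large")
model = BartForConditionalGeneration.from_pretrained("facebook/bart-large")

# train_records: assumed to be the output of create_qa_dataset(...) below, as a list of dicts.
trainer = Trainer(
    model=model,
    args=TrainingArguments(
        output_dir="tmp/bart-askathon-v1/",
        num_train_epochs=30,               # deliberately overfitting the 128 samples
        per_device_train_batch_size=4,
        learning_rate=5e-5,
        logging_steps=10,
    ),
    train_dataset=QADataset(train_records, tokenizer),
)
trainer.train()
trainer.save_model("tmp/bart-askathon-v1/")
```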
evalem code
The evalem code below is a monkey-patch where I have created a new temporary component for Generative QA.
I) Generative QA component (temporary for now)
```python
import logging

from transformers import BartForConditionalGeneration, BartTokenizer, TrainingArguments, Trainer
from typing import Iterable
from tqdm import tqdm
import pandas as pd

from evalem import NamedSimpleEvaluationPipeline
from evalem.nlp.models import QuestionAnsweringHFPipelineWrapper
from evalem.nlp.evaluators import QAEvaluator
from evalem.nlp.metrics import BertScore, RougeMetric, MeteorMetric, ExactMatchMetric, BartScore, BleuMetric
from evalem.nlp.structures import QuestionAnsweringDTO
from evalem.nlp.models._base import HFLMWrapper
from evalem.misc.utils import build_comparison_table

logger = logging.getLogger(__name__)  # used for the batch-size debug log in _predict
class GenerativeBartQAWrapper(HFLMWrapper):
    # def _predict(self, inputs, **kwargs):
    #     gen_ids = self.model.generate(
    #         inputs["input_ids"],
    #         attention_mask=inputs["attention_mask"]
    #     )
    #     return self.token_ids_to_token(gen_ids)

    def _predict(self, inputs, **kwargs):
        res = []
        batch_size = kwargs.get("batch_size", 8)
        logger.debug(f"batch_size={batch_size}")
        n_items = len(inputs["input_ids"])
        n_batches = (n_items + batch_size - 1) // batch_size  # ceil, so the bar covers the last partial batch
        with tqdm(total=n_batches) as pbar:
            for input_ids, attention_mask in zip(
                self.batch_iterator(inputs["input_ids"], batch_size),
                self.batch_iterator(inputs["attention_mask"], batch_size),
            ):
                # generate answer token ids for the batch and decode them back to text
                gen_ids = self.model.generate(
                    input_ids,
                    attention_mask=attention_mask,
                )
                tokens = self.token_ids_to_token(gen_ids)
                res.extend(tokens)
                pbar.update()
        return res

    @staticmethod
    def batch_iterator(iterable, batch_size):
        for start in range(0, len(iterable), batch_size):
            yield iterable[start:start + batch_size]

    def token_ids_to_token(self, token_ids):
        return [
            self.tokenizer.decode(token_id, skip_special_tokens=True) for token_id in token_ids
        ]
    def _preprocess_inputs(self, inputs: Iterable, **kwargs) -> Iterable:
        """
        A helper method to transform inputs into a form the model can ingest.

        Builds "context: ... / question: ..." prompts and tokenizes them as a
        single padded batch on the model's device.
        """
        input_texts = []
        labels = []
        for dct in inputs:
            ctx, q, a = dct["context"], dct["question"], dct.get("answer")
            input_texts.append(f"context: {ctx}\nquestion: {q}")
            labels.append(a)  # collected but not used during generation
        tokenized_inputs = self.tokenizer(input_texts, truncation=True, padding="longest", return_tensors="pt")
        return dict(
            input_ids=tokenized_inputs["input_ids"].to(self.model.device),
            attention_mask=tokenized_inputs["attention_mask"].to(self.model.device)
        )
```

II) Connecting the components
```python
DEVICE = "mps"
BATCH_SIZE = 8
tokenizer = BartTokenizer.from_pretrained("facebook/bart-large")
predictions_postprocessor = lambda x: list(map(lambda p: QuestionAnsweringDTO(value=p), x))
wrapped_model_1 = GenerativeBartQAWrapper(
    model=BartForConditionalGeneration.from_pretrained("tmp/bart-askathon-v1/").to(DEVICE),
    tokenizer=tokenizer,
    predictions_postprocessor=predictions_postprocessor,
)
wrapped_model_2 = GenerativeBartQAWrapper(
    model=BartForConditionalGeneration.from_pretrained("facebook/bart-large").to(DEVICE),
    tokenizer=tokenizer,
    predictions_postprocessor=predictions_postprocessor,
)

data = pd.DataFrame(get_askathon_data("data/askathon.csv"))\
    .rename(columns={"contexts": "context", "questions": "question", "answers": "answer"})#.to_dict("records")

evaluators_common = [
    QAEvaluator(),
    BertScore(device="mps"),
    BartScore(device="mps"),
    RougeMetric(),
    MeteorMetric(),
    BleuMetric(),
]

eval_pipe_1 = NamedSimpleEvaluationPipeline(
    model=wrapped_model_1,
    evaluators=evaluators_common,
    name="askathon-tuned"
)
eval_pipe_2 = NamedSimpleEvaluationPipeline(
    model=wrapped_model_2,
    evaluators=evaluators_common,
    name="bart-vanilla"
)

results = build_comparison_table(
    eval_pipe_1, eval_pipe_2,
    inputs=list(data[["context", "question"]].T.to_dict().values()),
    references=data["answer"].to_list(),
)
```

III) Askathon dataloader
```python
def load_askathon_clean(path: str) -> pd.DataFrame:
    data = pd.read_csv(path)
    data = data.drop(columns=["Email Address"]).reset_index(drop=True)
    data.rename(columns={
        data.columns[0]: "context",
        data.columns[1]: "id",
        data.columns[2]: "source",
        data.columns[3]: "topics",
        data.columns[4]: "q1",
        data.columns[5]: "a1",
        data.columns[6]: "q2",
        data.columns[7]: "a2",
        data.columns[8]: "q3",
        data.columns[9]: "a3",
        data.columns[10]: "q4",
        data.columns[11]: "a4",
        data.columns[12]: "q5",
        data.columns[13]: "a5"
    }, inplace=True)
    data.drop(columns=["source", "topics"], inplace=True)
    return data
def create_qa_dataset(data: pd.DataFrame) -> pd.DataFrame:
    res = []
    q_keys = [f"q{i}" for i in range(1, 6)]
    a_keys = [f"a{i}" for i in range(1, 6)]

    def _index_fn(context: str, answer: str) -> int:
        try:
            return context.lower().index(answer.rstrip(" ,.!?").lower())
        except ValueError:
            return -1

    for _df in data.itertuples():
        tmp = []
        for qk, ak in zip(q_keys, a_keys):
            q, a = getattr(_df, qk), getattr(_df, ak)
            if not isinstance(a, str):
                continue
            idx = _index_fn(_df.context, a)
            if idx > -1:
                tmp.append(dict(
                    id=str(_df.id),
                    context=_df.context,
                    question=q,
                    answer_text=a,
                    answer_start=idx,
                ))
        res.extend(tmp)
    return pd.DataFrame(res)
```
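One more note: `get_askathon_data`, which is called in section II, is not included above. Roughly, it just chains the two helpers; something like this simplified sketch (not the exact helper, and the plural column names are inferred from the `rename(...)` call in section II):

```python
# Simplified sketch of the helper called in section II (the real one may differ).
# It loads the raw CSV, builds the QA rows, and returns records whose plural
# column names match the rename(...) in section II.
def get_askathon_data(path: str) -> list:
    qa_df = create_qa_dataset(load_askathon_clean(path))
    return [
        {"contexts": r.context, "questions": r.question, "answers": r.answer_text}
        for r in qa_df.itertuples()
    ]
```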