#!/usr/bin/env python
# coding: utf-8
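# _tagging-hf-MEDIA-COST.py
#
# Fine-tunes a Hugging Face token-classification model on the MEDIA corpus
# (chunk tagging) and tracks the energy/carbon cost of training and inference
# with codecarbon and experiment_impact_tracker.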



from datasets import ClassLabel, Sequence
import random
import pandas as pd
from IPython.display import display, HTML
from codecarbon import OfflineEmissionsTracker
from experiment_impact_tracker.compute_tracker import ImpactTracker
import argparse
import shutil
import os
from datasets import load_dataset, load_metric
import transformers
from transformers import AutoTokenizer, AutoModelForTokenClassification, TrainingArguments, Trainer, DataCollatorForTokenClassification
import numpy as np
import tensorflow as tf
from sklearn.metrics import classification_report

def str_to_bool(value):
    # argparse's type=bool treats any non-empty string (including "False") as
    # True, so boolean options are parsed explicitly.
    return str(value).lower() in ("true", "1", "yes")


parser = argparse.ArgumentParser()

# Required parameters
parser.add_argument(
    "--model",
    default=None,
    type=str,
    required=True,
    help="Model given through huggingface (e.g.: qwant/fralbert-base)",
)
parser.add_argument(
    "--epoch",
    default=10,
    type=int,
    required=False,
    help="Number of epochs (10 by default)",
)
parser.add_argument(
    "--to_train",
    default=True,
    type=str_to_bool,
    required=False,
    help="Launch training (True by default)",
)
parser.add_argument(
    "--to_eval",
    default=True,
    type=str_to_bool,
    required=False,
    help="Launch evaluation (True by default)",
)
argparsed = parser.parse_args()


os.environ["WANDB_DISABLED"] = "true"
os.path.isdir("./")

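# "media.py" is the local dataset loading script for the MEDIA corpus; it reads
# the train/dev/test text files listed in data_files. The rest of the script
# assumes it exposes "train", "validation" and "test" splits (the dev file is
# expected to become the "validation" split).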
datasets = load_dataset(
    path = "media.py",
    data_files={
        "train":"train.txt",
        "dev": "dev.txt",
        "test": "test.txt"
    }
)   

label_list = datasets["train"].features[f"chunk_tags"].feature.names



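# Sanity-check helper: shows a few random examples with human-readable label
# names (the HTML display is mainly useful when run inside a notebook).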
def show_random_elements(dataset, num_examples=10):
    assert num_examples <= len(dataset), "Can't pick more elements than there are in the dataset."
    picks = []
    for _ in range(num_examples):
        pick = random.randint(0, len(dataset)-1)
        while pick in picks:
            pick = random.randint(0, len(dataset)-1)
        picks.append(pick)
    
    df = pd.DataFrame(dataset[picks])
    for column, typ in dataset.features.items():
        if isinstance(typ, ClassLabel):
            df[column] = df[column].transform(lambda i: typ.names[i])
        elif isinstance(typ, Sequence) and isinstance(typ.feature, ClassLabel):
            df[column] = df[column].transform(lambda x: [typ.feature.names[i] for i in x])
    display(HTML(df.to_html()))

show_random_elements(datasets["train"])


#model_checkpoint = "xlm-roberta-large" #ND
#model_checkpoint = "qwant/fralbert-base"
#model_checkpoint = "camembert/camembert-base-wikipedia-4gb"
model_checkpoint = argparsed.model
#qwant/fralbert-base
#xlm-roberta-base
#xlm-roberta-large
#Geotrend/bert-base-fr-cased
#distilbert-base-multilingual-cased
#albert-base-v2
#camembert-base
#camembert/camembert-large
#bert-base-multilingual-cased


task = "chunk" # Should be one of "ner", "pos" or "chunk"
batch_size = 8

tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
assert isinstance(tokenizer, transformers.PreTrainedTokenizerFast)


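# When True, every sub-token of a word receives the word's label; when False,
# only the first sub-token is labelled and the rest are set to -100 (ignored
# by the loss).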
label_all_tokens = True

def tokenize_and_align_labels(examples):
    tokenized_inputs = tokenizer(examples["tokens"], truncation=True, is_split_into_words=True)

    labels = []
    for i, label in enumerate(examples[f"{task}_tags"]):
        word_ids = tokenized_inputs.word_ids(batch_index=i)
        previous_word_idx = None
        label_ids = []
        for word_idx in word_ids:
            # Special tokens have a word id that is None. We set the label to -100 so they are automatically
            # ignored in the loss function.
            if word_idx is None:
                label_ids.append(-100)
            # We set the label for the first token of each word.
            elif word_idx != previous_word_idx:
                label_ids.append(label[word_idx])
            # For the other tokens in a word, we set the label to either the current label or -100, depending on
            # the label_all_tokens flag.
            else:
                label_ids.append(label[word_idx] if label_all_tokens else -100)
            previous_word_idx = word_idx

        labels.append(label_ids)

    tokenized_inputs["labels"] = labels
    return tokenized_inputs


tokenized_datasets = datasets.map(tokenize_and_align_labels, batched=True)


model = AutoModelForTokenClassification.from_pretrained(model_checkpoint, num_labels=len(label_list))

args = TrainingArguments(
    "models/media/outputs_" + model_checkpoint.replace("/", "_"),
    evaluation_strategy = "epoch",
    learning_rate=2e-5,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    num_train_epochs=argparsed.epoch,
    weight_decay=0.01,
)


data_collator = DataCollatorForTokenClassification(tokenizer)


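# seqeval computes chunk/entity-level precision, recall and F1 from IOB-style
# tag sequences.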
metric = load_metric("seqeval")

#example = datasets["train"][22]
#example["tokens"]


#labels = [label_list[i] for i in example[f"{task}_tags"]]
#metric.compute(predictions=[labels], references=[labels])



def compute_metrics(p):
    predictions, labels = p
    predictions = np.argmax(predictions, axis=2)

    # Remove ignored index (special tokens)
    true_predictions = [
        [label_list[p] for (p, l) in zip(prediction, label) if l != -100]
        for prediction, label in zip(predictions, labels)
    ]
    true_labels = [
        [label_list[l] for (p, l) in zip(prediction, label) if l != -100]
        for prediction, label in zip(predictions, labels)
    ]

    results = metric.compute(predictions=true_predictions, references=true_labels)
    return {
        "precision": results["overall_precision"],
        "recall": results["overall_recall"],
        "f1": results["overall_f1"],
        "accuracy": results["overall_accuracy"],
    }

def print_results(results):
    strout = "Precision: " + str(round(100*results["overall_precision"],2))
    strout = strout + "\nRecall: " + str(round(100*results["overall_recall"],2))
    strout = strout + "\nF1: " + str(round(100*results["overall_f1"],2))
    strout = strout + "\nAccuracy: " + str(round(100*results["overall_accuracy"],2))
    return strout


args.metric_for_best_model = "f1"
args.greater_is_better = True

trainer = Trainer(
    model,
    args,
    train_dataset=tokenized_datasets["train"],
    eval_dataset=tokenized_datasets["validation"],
    data_collator=data_collator,
    tokenizer=tokenizer,
    compute_metrics=compute_metrics,
)


if not os.path.exists("outputs_" + model_checkpoint.replace("/","_") + "/ImpactTrackerTrain"):
    os.makedirs("outputs_" + model_checkpoint.replace("/","_") + "/ImpactTrackerTrain")
if not os.path.exists("outputs_" + model_checkpoint.replace("/","_") + "/ImpactTrackerTest"):
    os.makedirs("outputs_" + model_checkpoint.replace("/","_") + "/ImpactTrackerTest")
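# codecarbon estimates CO2 emissions offline using the French energy mix (FRA)
# and appends its measurements to emissions.csv in the working directory.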
tracker = OfflineEmissionsTracker(country_iso_code="FRA")
#tracker_plop = ImpactTracker('.')
#tracker_plop.launch_impact_monitor()
tracker.start()
#print(tracker.gpu_ids)
#tracker.log_level = DEBUG


to_train = argparsed.to_train
to_evaluate = argparsed.to_eval

if to_train:
    impact_dir = "outputs_" + model_checkpoint.replace("/","_") + "/ImpactTrackerTrain"
    myimpacttrackertrain = ImpactTracker(impact_dir)
    myimpacttrackertrain.launch_impact_monitor()
    trainer.train()


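# Stop the training tracker; codecarbon returns the estimated emissions in kg CO2eq.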
emissions: float = tracker.stop()

trainer.evaluate(eval_dataset=tokenized_datasets["validation"])

if to_evaluate:
    tracker_test = OfflineEmissionsTracker(country_iso_code="FRA")
    tracker_test.start()

    impact_dir = "outputs_" + model_checkpoint.replace("/", "_") + "/ImpactTrackerTest"
    myimpacttrackerinference = ImpactTracker(impact_dir)
    myimpacttrackerinference.launch_impact_monitor()

    trainer.evaluate(eval_dataset=tokenized_datasets["test"])
    predictions, labels, _ = trainer.predict(tokenized_datasets["test"])
    predictions = np.argmax(predictions, axis=2)

    # Remove ignored index (special tokens)
    true_predictions = [
        [label_list[p] for (p, l) in zip(prediction, label) if l != -100]
        for prediction, label in zip(predictions, labels)
    ]
    true_labels = [
        [label_list[l] for (p, l) in zip(prediction, label) if l != -100]
        for prediction, label in zip(predictions, labels)
    ]
    #print(true_predictions)
    #print(true_labels)
    results = metric.compute(predictions=true_predictions, references=true_labels)
    emissions_test: float = tracker_test.stop()
    #print(classification_report(true_predictions, true_labels))
    print(print_results(results))

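# Move codecarbon's emissions.csv into the run's tracking directory.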
shutil.move("emissions.csv","outputs_" + model_checkpoint.replace("/","_") + "/emissions.csv")