Skip to content
Snippets Groups Projects
Commit 53c70272 authored by servan's avatar servan
Browse files

updated to change output directories

parent da6bc303
No related branches found
No related tags found
No related merge requests found
......@@ -147,7 +147,7 @@ tokenized_datasets = datasets.map(tokenize_and_align_labels, batched=True)
model = AutoModelForTokenClassification.from_pretrained(model_checkpoint, num_labels=len(label_list))
args = TrainingArguments(
f"models/media/"+model_checkpoint.replace("/","_"),
f"models/media/"+"outputs_" + model_checkpoint.replace("/","_"),
evaluation_strategy = "epoch",
learning_rate=2e-5,
per_device_train_batch_size=batch_size,
......@@ -166,8 +166,8 @@ metric = load_metric("seqeval")
#example["tokens"]
labels = [label_list[i] for i in example[f"{task}_tags"]]
metric.compute(predictions=[labels], references=[labels])
#labels = [label_list[i] for i in example[f"{task}_tags"]]
#metric.compute(predictions=[labels], references=[labels])
......@@ -216,10 +216,10 @@ trainer = Trainer(
)
if not os.path.exists(model_checkpoint.replace("/","_") + "/ImpactTrackerTrain"):
os.makedirs(model_checkpoint.replace("/","_") + "/ImpactTrackerTrain")
if not os.path.exists(model_checkpoint.replace("/","_") + "/ImpactTrackerTest"):
os.makedirs(model_checkpoint.replace("/","_") + "/ImpactTrackerTest")
if not os.path.exists("outputs_" + model_checkpoint.replace("/","_") + "/ImpactTrackerTrain"):
os.makedirs("outputs_" + model_checkpoint.replace("/","_") + "/ImpactTrackerTrain")
if not os.path.exists("outputs_" + model_checkpoint.replace("/","_") + "/ImpactTrackerTest"):
os.makedirs("outputs_" + model_checkpoint.replace("/","_") + "/ImpactTrackerTest")
tracker = OfflineEmissionsTracker(country_iso_code="FRA")
#tracker_plop = ImpactTracker('.')
#tracker_plop.launch_impact_monitor()
......@@ -232,7 +232,7 @@ to_train = True
to_evaluate = True
if to_train:
impact_dir = model_checkpoint.replace("/","_") + "/ImpactTrackerTrain"
impact_dir = "outputs_" + model_checkpoint.replace("/","_") + "/ImpactTrackerTrain"
myimpacttrackertrain = ImpactTracker(impact_dir)
myimpacttrackertrain.launch_impact_monitor()
trainer.train()
......@@ -245,7 +245,7 @@ tracker_test = OfflineEmissionsTracker(country_iso_code="FRA")
tracker_test.start()
if to_evaluate:
impact_dir = model_checkpoint.replace("/","_") + "/ImpactTrackerTest"
impact_dir = "outputs_" + model_checkpoint.replace("/","_") + "/ImpactTrackerTest"
myimpacttrackerinference = ImpactTracker(impact_dir)
myimpacttrackerinference.launch_impact_monitor()
trainer.evaluate(eval_dataset=tokenized_datasets["test"])
......@@ -268,6 +268,6 @@ emissions: float = tracker_test.stop()
#print(classification_report(true_predictions, true_labels))
print(print_results(results))
shutil.move("emissions.csv",model_checkpoint.replace("/","_") + "/emissions.csv")
shutil.move("emissions.csv","outputs_" + model_checkpoint.replace("/","_") + "/emissions.csv")
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment