From 53c702724e176a8ecbe0361bc3112c8f2a504601 Mon Sep 17 00:00:00 2001
From: servan <christophe.servan@lisn.upsaclay.fr>
Date: Mon, 12 Jun 2023 13:38:03 +0200
Subject: [PATCH] updated to change output directories

---
 _tagging-hf-MEDIA-COST.py | 20 ++++++++++----------
 1 file changed, 10 insertions(+), 10 deletions(-)

diff --git a/_tagging-hf-MEDIA-COST.py b/_tagging-hf-MEDIA-COST.py
index 065abf1..eb8063b 100644
--- a/_tagging-hf-MEDIA-COST.py
+++ b/_tagging-hf-MEDIA-COST.py
@@ -147,7 +147,7 @@ tokenized_datasets = datasets.map(tokenize_and_align_labels, batched=True)
 model = AutoModelForTokenClassification.from_pretrained(model_checkpoint, num_labels=len(label_list))
 
 args = TrainingArguments(
-    f"models/media/"+model_checkpoint.replace("/","_"),
+    f"models/media/"+"outputs_" + model_checkpoint.replace("/","_"),
     evaluation_strategy = "epoch",
     learning_rate=2e-5,
     per_device_train_batch_size=batch_size,
@@ -166,8 +166,8 @@ metric = load_metric("seqeval")
 #example["tokens"]
 
 
-labels = [label_list[i] for i in example[f"{task}_tags"]]
-metric.compute(predictions=[labels], references=[labels])
+#labels = [label_list[i] for i in example[f"{task}_tags"]]
+#metric.compute(predictions=[labels], references=[labels])
 
 
 
@@ -216,10 +216,10 @@ trainer = Trainer(
 )
 
 
-if not os.path.exists(model_checkpoint.replace("/","_") + "/ImpactTrackerTrain"):
-    os.makedirs(model_checkpoint.replace("/","_") + "/ImpactTrackerTrain")
-if not os.path.exists(model_checkpoint.replace("/","_") + "/ImpactTrackerTest"):
-    os.makedirs(model_checkpoint.replace("/","_") + "/ImpactTrackerTest")
+if not os.path.exists("outputs_" + model_checkpoint.replace("/","_") + "/ImpactTrackerTrain"):
+    os.makedirs("outputs_" + model_checkpoint.replace("/","_") + "/ImpactTrackerTrain")
+if not os.path.exists("outputs_" + model_checkpoint.replace("/","_") + "/ImpactTrackerTest"):
+    os.makedirs("outputs_" + model_checkpoint.replace("/","_") + "/ImpactTrackerTest")
 tracker = OfflineEmissionsTracker(country_iso_code="FRA")
 #tracker_plop = ImpactTracker('.')
 #tracker_plop.launch_impact_monitor()
@@ -232,7 +232,7 @@ to_train = True
 to_evaluate = True
 
 if to_train:
-    impact_dir = model_checkpoint.replace("/","_") + "/ImpactTrackerTrain"
+    impact_dir = "outputs_" + model_checkpoint.replace("/","_") + "/ImpactTrackerTrain"
     myimpacttrackertrain = ImpactTracker(impact_dir)
     myimpacttrackertrain.launch_impact_monitor()
     trainer.train()
@@ -245,7 +245,7 @@ tracker_test = OfflineEmissionsTracker(country_iso_code="FRA")
 tracker_test.start()
 
 if to_evaluate:
-    impact_dir = model_checkpoint.replace("/","_") + "/ImpactTrackerTest"
+    impact_dir = "outputs_" + model_checkpoint.replace("/","_") + "/ImpactTrackerTest"
     myimpacttrackerinference = ImpactTracker(impact_dir)
     myimpacttrackerinference.launch_impact_monitor()
     trainer.evaluate(eval_dataset=tokenized_datasets["test"])
@@ -268,6 +268,6 @@ emissions: float = tracker_test.stop()
 #print(classification_report(true_predictions, true_labels))
 print(print_results(results))
 
-shutil.move("emissions.csv",model_checkpoint.replace("/","_") + "/emissions.csv")
+shutil.move("emissions.csv","outputs_" + model_checkpoint.replace("/","_") + "/emissions.csv")
 
 
-- 
GitLab