Skip to content

Commit

Permalink
fix(weight): if weight doesn't exits in the folder then download it
Browse files Browse the repository at this point in the history
fixes bug on the weight not found issue
  • Loading branch information
kshitijrajsharma committed Jan 12, 2025
1 parent 9c6616a commit 9039a3d
Show file tree
Hide file tree
Showing 2 changed files with 79 additions and 28 deletions.
47 changes: 33 additions & 14 deletions hot_fair_utilities/training/yolo_v8_v1/train.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,21 @@
# Standard library imports
import argparse
import os
from pathlib import Path
import urllib.request

# Third party imports
import torch
import ultralytics


# Reader imports
from hot_fair_utilities.model.yolo import YOLOSegWithPosWeight
from ...utils import compute_iou_chart_from_yolo_results, get_yolo_iou_metrics,export_model_to_onnx

from ...utils import (
compute_iou_chart_from_yolo_results,
export_model_to_onnx,
get_yolo_iou_metrics,
)

# Get environment variables with fallbacks
# ROOT = Path(os.getenv("YOLO_ROOT", Path(__file__).parent.absolute()))
# DATA_ROOT = str(Path(os.getenv("YOLO_DATA_ROOT", ROOT / "yolo-training")))
Expand Down Expand Up @@ -91,6 +96,12 @@ def train(
output_path=None,
dataset_yaml_path=None,
):
if not os.path.exists(weights):
url = "https://github.com/hotosm/fAIr-utilities/raw/refs/heads/master/yolov8s_v1-seg-best.pt"
download_path = weights
urllib.request.urlretrieve(url, download_path)
weights = download_path
print(f"Weights file downloaded to {weights}")
back = (
"n"
if "yolov8n" in weights
Expand All @@ -110,33 +121,41 @@ def train(
else:
yolo = ultralytics.YOLO

weights, resume = check4checkpoint(name, weights,output_path)
weights, resume = check4checkpoint(name, weights, output_path)
model = yolo(weights)
model.train(
data=data_scn,
project=os.path.join(output_path,'checkpoints'),
project=os.path.join(output_path, "checkpoints"),
name=name,
epochs=int(epochs),
resume=resume,
deterministic=False,
verbose=True,
save_dir= os.path.join(output_path),
save_dir=os.path.join(output_path),
device=[int(i) for i in gpu.split(",")] if "," in gpu else gpu,
**kwargs,
)
compute_iou_chart_from_yolo_results(results_csv_path=os.path.join(output_path,"checkpoints", name,'results.csv'),results_output_chart_path=os.path.join(output_path,"checkpoints", name,'iou_chart.png'))

output_model_path=os.path.join(os.path.join(output_path,"checkpoints"), name, "weights", "best.pt")
compute_iou_chart_from_yolo_results(
results_csv_path=os.path.join(output_path, "checkpoints", name, "results.csv"),
results_output_chart_path=os.path.join(
output_path, "checkpoints", name, "iou_chart.png"
),
)

iou_model_accuracy=get_yolo_iou_metrics(output_model_path)
export_model_to_onnx(output_model_path)
output_model_path = os.path.join(
os.path.join(output_path, "checkpoints"), name, "weights", "best.pt"
)

return output_model_path,iou_model_accuracy
iou_model_accuracy = get_yolo_iou_metrics(output_model_path)
export_model_to_onnx(output_model_path)

return output_model_path, iou_model_accuracy


def check4checkpoint(name, weights,output_path):
ckpt = os.path.join(os.path.join(output_path,'checkpoints'), name, "weights", "last.pt")
def check4checkpoint(name, weights, output_path):
ckpt = os.path.join(
os.path.join(output_path, "checkpoints"), name, "weights", "last.pt"
)
if os.path.exists(ckpt):
print(f"Set weights to {ckpt}")
return ckpt, True
Expand Down
60 changes: 46 additions & 14 deletions hot_fair_utilities/training/yolo_v8_v2/train.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,20 @@
# Standard library imports
import argparse
import os
from pathlib import Path
import urllib.request

# Third party imports
import torch
import ultralytics
from ...utils import get_yolo_iou_metrics,compute_iou_chart_from_yolo_results,export_model_to_onnx

# Reader imports
from hot_fair_utilities.model.yolo import YOLOSegWithPosWeight

from ...utils import (
compute_iou_chart_from_yolo_results,
export_model_to_onnx,
get_yolo_iou_metrics,
)

# ROOT = Path(os.getenv("YOLO_ROOT", Path(__file__).parent.absolute()))
# DATA_ROOT = str(Path(os.getenv("YOLO_DATA_ROOT", ROOT / "yolo-training")))

Expand Down Expand Up @@ -51,7 +56,23 @@
}


def train(data, weights, epochs, batch_size, pc, output_path, dataset_yaml_path,gpu=("cuda" if torch.cuda.is_available() else "cpu"),):
def train(
data,
weights,
epochs,
batch_size,
pc,
output_path,
dataset_yaml_path,
gpu=("cuda" if torch.cuda.is_available() else "cpu"),
):
if not os.path.exists(weights):
url = "https://github.com/hotosm/fAIr-utilities/raw/refs/heads/master/yolov8s_v2-seg.pt"
download_path = weights
urllib.request.urlretrieve(url, download_path)
weights = download_path
print(f"Weights file downloaded to {weights}")

back = (
"n"
if "yolov8n" in weights
Expand All @@ -71,36 +92,47 @@ def train(data, weights, epochs, batch_size, pc, output_path, dataset_yaml_path,
else:
yolo = ultralytics.YOLO

weights, resume = check4checkpoint(name, weights,output_path)
weights, resume = check4checkpoint(name, weights, output_path)
model = yolo(weights)

model.train(
data=data_scn,
project=os.path.join(output_path,"checkpoints"), # Using the environment variable with fallback
project=os.path.join(
output_path, "checkpoints"
), # Using the environment variable with fallback
name=name,
epochs=int(epochs),
resume=resume,
verbose=True,
deterministic=False,
save_dir= os.path.join(output_path),
save_dir=os.path.join(output_path),
device=[int(i) for i in gpu.split(",")] if "," in gpu else gpu,
**kwargs,
)

# metrics = model.val(save_json=True, plots=True)
# print(model.val())
compute_iou_chart_from_yolo_results(results_csv_path=os.path.join(output_path,"checkpoints", name,'results.csv'),results_output_chart_path=os.path.join(output_path,"checkpoints", name,'iou_chart.png'))

output_model_path=os.path.join(os.path.join(output_path,"checkpoints"), name, "weights", "best.pt")
compute_iou_chart_from_yolo_results(
results_csv_path=os.path.join(output_path, "checkpoints", name, "results.csv"),
results_output_chart_path=os.path.join(
output_path, "checkpoints", name, "iou_chart.png"
),
)

iou_model_accuracy=get_yolo_iou_metrics(output_model_path)
output_model_path = os.path.join(
os.path.join(output_path, "checkpoints"), name, "weights", "best.pt"
)

iou_model_accuracy = get_yolo_iou_metrics(output_model_path)
export_model_to_onnx(output_model_path)

return output_model_path,iou_model_accuracy
return output_model_path, iou_model_accuracy


def check4checkpoint(name, weights,output_path):
ckpt = os.path.join(os.path.join(output_path,"checkpoints"), name, "weights", "last.pt")
def check4checkpoint(name, weights, output_path):
ckpt = os.path.join(
os.path.join(output_path, "checkpoints"), name, "weights", "last.pt"
)
if os.path.exists(ckpt):
print(f"Set weights to {ckpt}")
return ckpt, True
Expand Down

0 comments on commit 9039a3d

Please sign in to comment.