158 lines
4.7 KiB
Python
158 lines
4.7 KiB
Python
|
|
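"""
Train a knot detector, nominally YOLOX.

YOLOX is from Megvii, designed for real-time detection. Ultralytics does not
support YOLOX directly, so this script actually trains the closest-sized YOLOv8
model. Note that the Ultralytics package and its YOLOv8 weights are distributed
under AGPL-3.0, so review that license before any commercial use.

Usage:
    python train_yolox.py --dataset-dir dataset_prepared --model yolox-nano --epochs 100
"""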
|
||
|
|
|
||
|
|
import argparse
|
||
|
|
from pathlib import Path
|
||
|
|
import torch
|
||
|
|
|
||
|
|
|
||
|
|
def train_yolox(
    dataset_dir: Path,
    output_dir: Path,
    model_name: str = "yolox-nano",
    epochs: int = 100,
    batch_size: int = 8,
    img_size: int = 640,
    learning_rate: float = 1e-3,
    optimizer: str = "AdamW",
):
    """
    Train a knot detector with Ultralytics, using YOLOv8 as a stand-in for YOLOX.

    Ultralytics does not ship YOLOX weights, so each YOLOX variant name is mapped
    to the closest-sized YOLOv8 checkpoint ("yolox-tiny" reuses the nano weights).

    Args:
        dataset_dir: Path to dataset with train/valid/test splits and a data.yaml
        output_dir: Where to save checkpoints (the run is written to
            ``output_dir / "training"``)
        model_name: One of ['yolox-nano', 'yolox-tiny', 'yolox-s', 'yolox-m',
            'yolox-l', 'yolox-x']
        epochs: Number of training epochs
        batch_size: Batch size
        img_size: Input image size
        learning_rate: Initial learning rate (passed to Ultralytics as ``lr0``)
        optimizer: Ultralytics optimizer name (e.g. "SGD", "Adam", "AdamW").
            Avoid "auto" if ``learning_rate`` should take effect: with "auto",
            Ultralytics chooses its own optimizer and learning rate and ignores lr0.

    Returns:
        Whatever ``YOLO.train`` returns.

    Raises:
        ValueError: If ``model_name`` is unknown, the train/ or valid/ split is
            missing, or ``data.yaml`` is missing.
    """
    # YOLOX variant name -> YOLOv8 checkpoint of comparable size.
    yolov8_weights = {
        "yolox-nano": "yolov8n.pt",
        "yolox-tiny": "yolov8n.pt",  # no YOLOv8 size between "n" and "s"
        "yolox-s": "yolov8s.pt",
        "yolox-m": "yolov8m.pt",
        "yolox-l": "yolov8l.pt",
        "yolox-x": "yolov8x.pt",
    }

    # Validate every input before any side effect (heavy import, mkdir, training).
    if model_name not in yolov8_weights:
        raise ValueError(f"Model must be one of {list(yolov8_weights.keys())}")

    train_dir = dataset_dir / "train"
    valid_dir = dataset_dir / "valid"
    if not train_dir.exists() or not valid_dir.exists():
        raise ValueError(
            f"Dataset must have train/ and valid/ directories (looked in {dataset_dir})"
        )

    data_yaml = dataset_dir / "data.yaml"
    if not data_yaml.exists():
        raise ValueError(f"Missing {data_yaml}. Run reorganize_dataset.py first!")

    # Imported lazily so argument validation works without Ultralytics installed.
    from ultralytics import YOLO

    weights = yolov8_weights[model_name]
    use_cuda = torch.cuda.is_available()

    output_dir.mkdir(parents=True, exist_ok=True)

    print(f"\n{'='*60}")
    print(f"Training {model_name} (YOLOv8 stand-in: {weights})")
    print(f"Dataset: {dataset_dir}")
    print(f"Output: {output_dir}")
    print(f"Device: {'cuda' if use_cuda else 'cpu'}")
    print(f"{'='*60}\n")

    print("Note: Using YOLOv8 as Ultralytics doesn't directly support YOLOX")
    print("YOLOv8 has similar performance and better maintained\n")

    model = YOLO(weights)

    results = model.train(
        data=str(data_yaml),
        epochs=epochs,
        batch=batch_size,
        imgsz=img_size,
        lr0=learning_rate,
        # Explicit optimizer: Ultralytics' default "auto" silently ignores lr0.
        optimizer=optimizer,
        project=str(output_dir),
        name="training",
        exist_ok=True,
        patience=20,  # Early stopping
        save=True,
        save_period=10,  # Save every 10 epochs
        plots=True,
        device=0 if use_cuda else "cpu",
    )

    # Report the directory the trainer actually wrote to; fall back to the
    # layout implied by project/name if the trainer does not expose it.
    trainer = getattr(model, "trainer", None)
    save_dir = Path(getattr(trainer, "save_dir", output_dir / "training"))
    weights_dir = save_dir / "weights"

    print(f"\n{'='*60}")
    print("✓ Training complete!")
    print(f"Best weights: {weights_dir / 'best.pt'}")
    print(f"Last weights: {weights_dir / 'last.pt'}")
    print(f"{'='*60}\n")

    return results
|
||
|
|
|
||
|
|
|
||
|
|
def main():
    """Parse the command line and hand the options to ``train_yolox``."""
    cli = argparse.ArgumentParser(description="Train YOLOX for knot detection")

    cli.add_argument(
        "--dataset-dir",
        type=Path,
        required=True,
        help="Path to dataset directory with train/valid/test splits",
    )
    cli.add_argument(
        "--output-dir",
        type=Path,
        default=Path("runs/yolox_training"),
        help="Output directory for checkpoints and logs",
    )

    variants = ["yolox-nano", "yolox-tiny", "yolox-s", "yolox-m", "yolox-l", "yolox-x"]
    cli.add_argument(
        "--model",
        type=str,
        choices=variants,
        default="yolox-nano",
        help="YOLOX model variant (nano=smallest/fastest, x=largest/most accurate)",
    )

    # Simple scalar hyper-parameters, registered in the same order as before.
    scalar_options = (
        ("--epochs", int, 100, "Number of epochs"),
        ("--batch-size", int, 8, "Batch size"),
        ("--img-size", int, 640, "Input image size"),
        ("--lr", float, 1e-3, "Learning rate"),
    )
    for flag, caster, fallback, description in scalar_options:
        cli.add_argument(flag, type=caster, default=fallback, help=description)

    opts = cli.parse_args()

    train_yolox(
        dataset_dir=opts.dataset_dir,
        output_dir=opts.output_dir,
        model_name=opts.model,
        epochs=opts.epochs,
        batch_size=opts.batch_size,
        img_size=opts.img_size,
        learning_rate=opts.lr,
    )
|
||
|
|
|
||
|
|
|
||
|
|
if __name__ == "__main__":
    # Launch the CLI only when run as a script, never on import.
    main()
|