"""
Train a knot detector with Ultralytics, using YOLOv8 as a stand-in for YOLOX.

YOLOX (Megvii) is a real-time, anchor-free object detector. Ultralytics does not
provide YOLOX weights, so every ``yolox-*`` variant accepted here is mapped to
the closest-sized YOLOv8 checkpoint, and that checkpoint is trained instead.

NOTE(review): Ultralytics / YOLOv8 is distributed under AGPL-3.0 (YOLOX itself
is Apache-2.0), so this training pipeline is *not* MIT-licensed. Confirm
licensing before any commercial use.

Usage:
    python train_yolox.py --dataset-dir dataset_prepared --model yolox-nano --epochs 100
"""

import argparse
from pathlib import Path

import torch

# Requested YOLOX variant -> YOLOv8 checkpoint that is actually trained.
# YOLOv8 has no "tiny" size, so yolox-tiny falls back to the nano weights.
# This single table drives validation, weight lookup and the CLI choices.
YOLOX_TO_YOLOV8 = {
    "yolox-nano": "yolov8n.pt",
    "yolox-tiny": "yolov8n.pt",
    "yolox-s": "yolov8s.pt",
    "yolox-m": "yolov8m.pt",
    "yolox-l": "yolov8l.pt",
    "yolox-x": "yolov8x.pt",
}


def train_yolox(
    dataset_dir: Path,
    output_dir: Path,
    model_name: str = "yolox-nano",
    epochs: int = 100,
    batch_size: int = 8,
    img_size: int = 640,
    learning_rate: float = 1e-3,
):
    """
    Train the YOLOv8 substitute for a YOLOX variant using Ultralytics.

    All arguments are validated before Ultralytics is imported or any directory
    is created, so bad input fails fast and leaves nothing behind.

    Args:
        dataset_dir: Dataset root (Path or str) containing ``train/``,
            ``valid/`` and ``data.yaml``.
        output_dir: Where checkpoints and logs are written (Path or str);
            created if missing.
        model_name: One of 'yolox-nano', 'yolox-tiny', 'yolox-s', 'yolox-m',
            'yolox-l', 'yolox-x' (the keys of ``YOLOX_TO_YOLOV8``).
        epochs: Number of training epochs.
        batch_size: Batch size.
        img_size: Input image size.
        learning_rate: Initial learning rate (passed as Ultralytics ``lr0``).

    Returns:
        Whatever ``YOLO.train`` returns for the installed Ultralytics version.

    Raises:
        ValueError: If ``model_name`` is unknown, the ``train/``/``valid/``
            splits are missing, or ``data.yaml`` is missing.
    """
    # Accept plain strings from programmatic callers, not only Path objects.
    dataset_dir = Path(dataset_dir)
    output_dir = Path(output_dir)

    # --- Validate everything before the heavy import or any filesystem writes.
    if model_name not in YOLOX_TO_YOLOV8:
        raise ValueError(f"Model must be one of {list(YOLOX_TO_YOLOV8)}")

    train_dir = dataset_dir / "train"
    valid_dir = dataset_dir / "valid"
    if not train_dir.exists() or not valid_dir.exists():
        raise ValueError("Dataset must have train/ and valid/ directories")

    data_yaml = dataset_dir / "data.yaml"
    if not data_yaml.exists():
        raise ValueError(f"Missing {data_yaml}. Run reorganize_dataset.py first!")

    # Deferred import: only pay for Ultralytics once the inputs are known good.
    from ultralytics import YOLO

    weights = YOLOX_TO_YOLOV8[model_name]
    use_cuda = torch.cuda.is_available()

    output_dir.mkdir(parents=True, exist_ok=True)

    print(f"\n{'='*60}")
    print(f"Training {model_name} (YOLOv8 substitute: {weights})")
    print(f"Dataset: {dataset_dir}")
    print(f"Output: {output_dir}")
    print(f"Device: {'cuda' if use_cuda else 'cpu'}")
    print(f"{'='*60}\n")

    print("Note: Using YOLOv8 as Ultralytics doesn't directly support YOLOX")
    print("YOLOv8 has similar performance and better maintained\n")

    model = YOLO(weights)
    results = model.train(
        data=str(data_yaml),
        epochs=epochs,
        batch=batch_size,
        imgsz=img_size,
        lr0=learning_rate,
        project=str(output_dir),
        name="training",
        exist_ok=True,  # allow reusing an existing output_dir/training run dir
        patience=20,  # early stopping after 20 epochs without improvement
        save=True,
        save_period=10,  # additionally keep a checkpoint every 10 epochs
        plots=True,
        device=0 if use_cuda else "cpu",
    )

    # Report the directory Ultralytics actually used; fall back to the
    # project/name path requested above if the trainer does not expose it.
    trainer = getattr(model, "trainer", None)
    save_dir = Path(getattr(trainer, "save_dir", None) or output_dir / "training")
    weights_dir = save_dir / "weights"

    print(f"\n{'='*60}")
    print("✓ Training complete!")
    print(f"Best weights: {weights_dir / 'best.pt'}")
    print(f"Last weights: {weights_dir / 'last.pt'}")
    print(f"{'='*60}\n")

    return results


def main() -> None:
    """Parse command-line arguments and run :func:`train_yolox`."""
    parser = argparse.ArgumentParser(description="Train YOLOX for knot detection")
    parser.add_argument(
        "--dataset-dir",
        type=Path,
        required=True,
        help="Path to dataset directory with train/valid/test splits",
    )
    parser.add_argument(
        "--output-dir",
        type=Path,
        default=Path("runs/yolox_training"),
        help="Output directory for checkpoints and logs",
    )
    parser.add_argument(
        "--model",
        type=str,
        choices=list(YOLOX_TO_YOLOV8),
        default="yolox-nano",
        help="YOLOX model variant (nano=smallest/fastest, x=largest/most accurate)",
    )
    parser.add_argument("--epochs", type=int, default=100, help="Number of epochs")
    parser.add_argument("--batch-size", type=int, default=8, help="Batch size")
    parser.add_argument("--img-size", type=int, default=640, help="Input image size")
    parser.add_argument("--lr", type=float, default=1e-3, help="Learning rate")

    args = parser.parse_args()

    train_yolox(
        dataset_dir=args.dataset_dir,
        output_dir=args.output_dir,
        model_name=args.model,
        epochs=args.epochs,
        batch_size=args.batch_size,
        img_size=args.img_size,
        learning_rate=args.lr,
    )


if __name__ == "__main__":
    main()