""" Train RT-DETR for knot detection (Apache 2.0 license - free for commercial use). RT-DETR is a real-time transformer detector that works well on edge devices like OAK cameras. Usage: python train_rtdetr.py --dataset-dir dataset_prepared --model rtdetr-r18 --epochs 100 """ import argparse from pathlib import Path import torch def train_rtdetr( dataset_dir: Path, output_dir: Path, model_name: str = "rtdetr-r18", epochs: int = 100, batch_size: int = 8, img_size: int = 640, learning_rate: float = 1e-4, ): """ Train RT-DETR model. Args: dataset_dir: Path to dataset with train/valid/test splits output_dir: Where to save checkpoints model_name: One of ['rtdetr-r18', 'rtdetr-r34', 'rtdetr-r50', 'rtdetr-l'] epochs: Number of training epochs batch_size: Batch size img_size: Input image size learning_rate: Learning rate """ from ultralytics import RTDETR # Validate dataset structure train_dir = dataset_dir / "train" valid_dir = dataset_dir / "valid" if not train_dir.exists() or not valid_dir.exists(): raise ValueError(f"Dataset must have train/ and valid/ directories") train_ann = train_dir / "_annotations.coco.json" valid_ann = valid_dir / "_annotations.coco.json" if not train_ann.exists(): raise ValueError(f"Missing {train_ann}") if not valid_ann.exists(): raise ValueError(f"Missing {valid_ann}") # Create output directory output_dir.mkdir(parents=True, exist_ok=True) # Create data config file for RT-DETR data_yaml = output_dir / "data.yaml" with data_yaml.open("w") as f: f.write(f"""path: {dataset_dir.absolute()} train: train val: valid nc: 1 names: ['knot'] """) print(f"\n{'='*60}") print(f"Training RT-DETR-{model_name}") print(f"Dataset: {dataset_dir}") print(f"Output: {output_dir}") print(f"Device: {'cuda' if torch.cuda.is_available() else 'cpu'}") print(f"{'='*60}\n") # Map model name to pretrained weights model_map = { "rtdetr-r18": "rtdetr-l.pt", # Use available large model as r18 substitute "rtdetr-r34": "rtdetr-l.pt", # Use available large model as r34 substitute "rtdetr-r50": "rtdetr-l.pt", # Use available large model as r50 substitute "rtdetr-l": "rtdetr-l.pt", } if model_name not in model_map: raise ValueError(f"Model must be one of {list(model_map.keys())}") # Initialize model (Ultralytics will auto-download pretrained weights) model = RTDETR(model_map[model_name]) # Train results = model.train( data=str(data_yaml), epochs=epochs, batch=batch_size, imgsz=img_size, lr0=learning_rate, project=str(output_dir), name="training", exist_ok=True, patience=20, # Early stopping save=True, save_period=10, # Save every 10 epochs plots=True, device=0 if torch.cuda.is_available() else "cpu", ) print(f"\n{'='*60}") print(f"✓ Training complete!") print(f"Best weights: {output_dir / 'training/weights/best.pt'}") print(f"Last weights: {output_dir / 'training/weights/last.pt'}") print(f"{'='*60}\n") return results def main(): parser = argparse.ArgumentParser(description="Train RT-DETR for knot detection") parser.add_argument( "--dataset-dir", type=Path, required=True, help="Path to dataset directory with train/valid/test splits" ) parser.add_argument( "--output-dir", type=Path, default=Path("runs/rtdetr_training"), help="Output directory for checkpoints and logs" ) parser.add_argument( "--model", type=str, choices=["rtdetr-r18", "rtdetr-r34", "rtdetr-r50", "rtdetr-l"], default="rtdetr-r18", help="RT-DETR model variant (r18=smallest/fastest, l=largest/most accurate)" ) parser.add_argument("--epochs", type=int, default=100, help="Number of epochs") parser.add_argument("--batch-size", type=int, default=8, help="Batch size") parser.add_argument("--img-size", type=int, default=640, help="Input image size") parser.add_argument("--lr", type=float, default=1e-4, help="Learning rate") args = parser.parse_args() train_rtdetr( dataset_dir=args.dataset_dir, output_dir=args.output_dir, model_name=args.model, epochs=args.epochs, batch_size=args.batch_size, img_size=args.img_size, learning_rate=args.lr, ) if __name__ == "__main__": main()