from __future__ import annotations

import argparse
from collections.abc import Sequence
from pathlib import Path

# Supported ``--model`` sizes. Shared by argparse ``choices`` and
# ``_build_model`` so the two lists cannot drift apart.
_MODEL_CHOICES = ("nano", "small", "medium", "base")


def _build_model(model_name: str, pretrain_weights: str | None = None):
    """Instantiate the RF-DETR model class matching *model_name*.

    Args:
        model_name: One of ``nano``, ``small``, ``medium``, ``base``
            (case-insensitive; surrounding whitespace is ignored).
        pretrain_weights: Optional weights file passed to the model
            constructor. When falsy, the class is constructed with no
            arguments so it uses the library's default weights.

    Returns:
        An instance of the selected ``rfdetr`` model class.

    Raises:
        ValueError: If *model_name* is not a supported size. This is checked
            before ``rfdetr`` is imported.
    """
    key = model_name.lower().strip()
    if key not in _MODEL_CHOICES:
        raise ValueError(f"--model must be one of: {', '.join(_MODEL_CHOICES)}")

    # Import here so `--help` and argument validation work without heavy imports.
    from rfdetr import RFDETRBase, RFDETRMedium, RFDETRNano, RFDETRSmall

    model_cls = {
        "nano": RFDETRNano,
        "small": RFDETRSmall,
        "medium": RFDETRMedium,
        "base": RFDETRBase,
    }[key]
    if pretrain_weights:
        return model_cls(pretrain_weights=pretrain_weights)
    return model_cls()


def _positive_int(text: str) -> int:
    """argparse ``type=`` callable that accepts only integers >= 1."""
    try:
        value = int(text)
    except ValueError:
        raise argparse.ArgumentTypeError(f"expected a positive integer, got {text!r}") from None
    if value < 1:
        raise argparse.ArgumentTypeError(f"expected a positive integer, got {value}")
    return value


def main(argv: Sequence[str] | None = None) -> int:
    """Parse CLI arguments, validate inputs, and fine-tune an RF-DETR model.

    Args:
        argv: Arguments to parse instead of ``sys.argv[1:]``. ``None`` (the
            default) keeps normal command-line behaviour.

    Returns:
        ``0`` on success, used as the process exit status.

    Raises:
        SystemExit: On invalid arguments (exit code 2, from argparse) or when
            the dataset directory or resume checkpoint does not exist.
    """
    parser = argparse.ArgumentParser(description="Fine-tune RF-DETR on a COCO-format knot dataset.")
    parser.add_argument("--dataset-dir", type=Path, required=True, help="Dataset root containing train/valid/test")
    parser.add_argument("--output-dir", type=Path, required=True, help="Directory where checkpoints/logs are written")
    parser.add_argument(
        "--model",
        type=str.lower,  # accept e.g. "Medium"; _build_model is case-insensitive too
        default="medium",
        choices=_MODEL_CHOICES,
        help="Checkpoint size",
    )
    parser.add_argument("--epochs", type=_positive_int, default=50)
    parser.add_argument("--batch-size", type=_positive_int, default=4)
    parser.add_argument("--grad-accum-steps", type=_positive_int, default=4)
    parser.add_argument("--lr", type=float, default=1e-4)
    parser.add_argument("--resume", type=Path, default=None, help="Path to checkpoint.pth to resume training")
    parser.add_argument(
        "--pretrain-weights",
        type=Path,
        default=None,
        help="Optional: start from a specific weights file instead of the default COCO pretrain",
    )
    parser.add_argument("--early-stopping", action="store_true", help="Enable early stopping on validation mAP")
    parser.add_argument("--tensorboard", action="store_true", help="Enable TensorBoard logging")
    parser.add_argument("--wandb", action="store_true", help="Enable Weights & Biases logging")
    args = parser.parse_args(argv)

    # Validate every input path before any side effect (mkdir) or heavy work
    # (importing rfdetr / constructing the model), so mistakes fail fast.
    dataset_dir: Path = args.dataset_dir
    if not dataset_dir.exists():
        raise SystemExit(f"Dataset dir not found: {dataset_dir}")
    if not dataset_dir.is_dir():
        raise SystemExit(f"Dataset path is not a directory: {dataset_dir}")

    resume: Path | None = args.resume
    if resume is not None and not resume.is_file():
        raise SystemExit(f"Resume checkpoint not found: {resume}")

    # NOTE: --pretrain-weights is deliberately passed through unchecked; its
    # resolution/validation is left to rfdetr.
    pretrain_weights: Path | None = args.pretrain_weights

    output_dir: Path = args.output_dir
    output_dir.mkdir(parents=True, exist_ok=True)

    model = _build_model(
        args.model,
        pretrain_weights=str(pretrain_weights) if pretrain_weights is not None else None,
    )

    # Train. RF-DETR handles reading COCO split annotation files.
    model.train(
        dataset_dir=str(dataset_dir),
        epochs=args.epochs,
        batch_size=args.batch_size,
        grad_accum_steps=args.grad_accum_steps,
        lr=args.lr,
        output_dir=str(output_dir),
        resume=str(resume) if resume is not None else None,
        early_stopping=args.early_stopping,
        tensorboard=args.tensorboard,
        wandb=args.wandb,
    )

    print(f"Training complete. Outputs in: {output_dir}")
    return 0


if __name__ == "__main__":
    raise SystemExit(main())