# saw_mill_knot_detection/train_rfdetr.py
# Fine-tune RF-DETR on a COCO-format knot dataset (command-line entry point).
from __future__ import annotations
import argparse
from pathlib import Path
def _build_model(model_name: str, pretrain_weights: str | None = None):
# Import here so `--help` works without heavy imports.
from rfdetr import RFDETRBase, RFDETRMedium, RFDETRNano, RFDETRSmall
model_name = model_name.lower().strip()
if model_name == "nano":
return RFDETRNano(pretrain_weights=pretrain_weights) if pretrain_weights else RFDETRNano()
if model_name == "small":
return RFDETRSmall(pretrain_weights=pretrain_weights) if pretrain_weights else RFDETRSmall()
if model_name == "medium":
return RFDETRMedium(pretrain_weights=pretrain_weights) if pretrain_weights else RFDETRMedium()
if model_name == "base":
return RFDETRBase(pretrain_weights=pretrain_weights) if pretrain_weights else RFDETRBase()
raise ValueError("--model must be one of: nano, small, medium, base")
def _build_parser() -> argparse.ArgumentParser:
    """Construct the command-line interface for the training script."""
    p = argparse.ArgumentParser(description="Fine-tune RF-DETR on a COCO-format knot dataset.")
    p.add_argument("--dataset-dir", type=Path, required=True, help="Dataset root containing train/valid/test")
    p.add_argument("--output-dir", type=Path, required=True, help="Directory where checkpoints/logs are written")
    p.add_argument("--model", default="medium", choices=["nano", "small", "medium", "base"], help="Checkpoint size")
    p.add_argument("--epochs", type=int, default=50)
    p.add_argument("--batch-size", type=int, default=4)
    p.add_argument("--grad-accum-steps", type=int, default=4)
    p.add_argument("--lr", type=float, default=1e-4)
    p.add_argument("--resume", type=Path, default=None, help="Path to checkpoint.pth to resume training")
    p.add_argument(
        "--pretrain-weights",
        type=Path,
        default=None,
        help="Optional: start from a specific weights file instead of the default COCO pretrain",
    )
    p.add_argument("--early-stopping", action="store_true", help="Enable early stopping on validation mAP")
    p.add_argument("--tensorboard", action="store_true", help="Enable TensorBoard logging")
    p.add_argument("--wandb", action="store_true", help="Enable Weights & Biases logging")
    return p


def main() -> int:
    """Parse CLI arguments, prepare directories, and run RF-DETR fine-tuning.

    Returns:
        Process exit code (0 on success).

    Raises:
        SystemExit: If the dataset directory does not exist.
    """
    ns = _build_parser().parse_args()

    data_root: Path = ns.dataset_dir
    if not data_root.exists():
        raise SystemExit(f"Dataset dir not found: {data_root}")

    run_dir: Path = ns.output_dir
    run_dir.mkdir(parents=True, exist_ok=True)

    weights = str(ns.pretrain_weights) if ns.pretrain_weights else None
    model = _build_model(ns.model, pretrain_weights=weights)

    # Train. RF-DETR handles reading COCO split annotation files.
    model.train(
        dataset_dir=str(data_root),
        epochs=ns.epochs,
        batch_size=ns.batch_size,
        grad_accum_steps=ns.grad_accum_steps,
        lr=ns.lr,
        output_dir=str(run_dir),
        resume=str(ns.resume) if ns.resume else None,
        early_stopping=ns.early_stopping,
        tensorboard=ns.tensorboard,
        wandb=ns.wandb,
    )
    print(f"Training complete. Outputs in: {run_dir}")
    return 0
# Script entry point: exit the process with main()'s return code.
if __name__ == "__main__":
    raise SystemExit(main())