82 lines
3.2 KiB
Python
82 lines
3.2 KiB
Python
from __future__ import annotations
|
|
|
|
import argparse
|
|
from pathlib import Path
|
|
|
|
|
|
def _build_model(model_name: str, pretrain_weights: str | None = None):
    """Instantiate the RF-DETR model class selected by *model_name*.

    Args:
        model_name: Model size, one of "nano", "small", "medium" or "base".
            Matched case-insensitively after stripping surrounding whitespace.
        pretrain_weights: Optional weights path forwarded to the model
            constructor. When falsy (``None`` or ``""``) the constructor is
            called with no arguments, so its own defaults apply.

    Returns:
        A fresh instance of the matching RF-DETR model class.

    Raises:
        ValueError: If *model_name* is not one of the supported sizes.
    """
    # Deferred import so `--help` works without pulling in heavy dependencies.
    from rfdetr import RFDETRBase, RFDETRMedium, RFDETRNano, RFDETRSmall

    size_to_class = {
        "nano": RFDETRNano,
        "small": RFDETRSmall,
        "medium": RFDETRMedium,
        "base": RFDETRBase,
    }

    model_cls = size_to_class.get(model_name.lower().strip())
    if model_cls is None:
        raise ValueError("--model must be one of: nano, small, medium, base")

    # Only forward pretrain_weights when one was supplied; otherwise let the
    # constructor fall back to its default.
    ctor_kwargs = {"pretrain_weights": pretrain_weights} if pretrain_weights else {}
    return model_cls(**ctor_kwargs)
|
|
|
|
|
|
def main(argv: list[str] | None = None) -> int:
    """Parse CLI arguments, validate them, and fine-tune an RF-DETR model.

    Args:
        argv: Argument list to parse instead of ``sys.argv[1:]``. ``None``
            (the default) keeps normal command-line behavior; passing a list
            allows calling the script programmatically or from tests.

    Returns:
        Process exit status (0 on success).

    Raises:
        SystemExit: On an argparse usage error (exit code 2, including
            non-positive numeric hyperparameters), when the dataset dir is
            missing or not a directory, or when ``--resume`` does not point
            to an existing file.
    """
    parser = argparse.ArgumentParser(description="Fine-tune RF-DETR on a COCO-format knot dataset.")
    parser.add_argument("--dataset-dir", type=Path, required=True, help="Dataset root containing train/valid/test")
    parser.add_argument("--output-dir", type=Path, required=True, help="Directory where checkpoints/logs are written")
    parser.add_argument("--model", default="medium", choices=["nano", "small", "medium", "base"], help="Checkpoint size")

    parser.add_argument("--epochs", type=int, default=50)
    parser.add_argument("--batch-size", type=int, default=4)
    parser.add_argument("--grad-accum-steps", type=int, default=4)
    parser.add_argument("--lr", type=float, default=1e-4)

    parser.add_argument("--resume", type=Path, default=None, help="Path to checkpoint.pth to resume training")
    parser.add_argument(
        "--pretrain-weights",
        type=Path,
        default=None,
        help="Optional: start from a specific weights file instead of the default COCO pretrain",
    )

    parser.add_argument("--early-stopping", action="store_true", help="Enable early stopping on validation mAP")
    parser.add_argument("--tensorboard", action="store_true", help="Enable TensorBoard logging")
    parser.add_argument("--wandb", action="store_true", help="Enable Weights & Biases logging")

    args = parser.parse_args(argv)

    # Reject nonsensical numeric hyperparameters immediately with a usage
    # error, before any model is built.
    for flag, value in (
        ("--epochs", args.epochs),
        ("--batch-size", args.batch_size),
        ("--grad-accum-steps", args.grad_accum_steps),
    ):
        if value < 1:
            parser.error(f"{flag} must be >= 1 (got {value})")
    # `not (x > 0)` also rejects NaN, which `x <= 0` would let through.
    if not (args.lr > 0):
        parser.error(f"--lr must be > 0 (got {args.lr})")

    dataset_dir: Path = args.dataset_dir
    if not dataset_dir.exists():
        raise SystemExit(f"Dataset dir not found: {dataset_dir}")
    # A regular file passes exists() but cannot be a dataset root.
    if not dataset_dir.is_dir():
        raise SystemExit(f"Dataset dir is not a directory: {dataset_dir}")

    # Fail fast on a mistyped --resume path rather than only after the model
    # has been constructed.
    if args.resume is not None and not args.resume.is_file():
        raise SystemExit(f"Resume checkpoint not found: {args.resume}")

    # NOTE(review): --pretrain-weights is intentionally not existence-checked;
    # it is passed verbatim to the rfdetr constructor, which may resolve
    # names on its own — confirm against the installed rfdetr version.

    # Create the output dir only after every input validated, so a failed
    # invocation does not leave an empty directory behind.
    output_dir: Path = args.output_dir
    output_dir.mkdir(parents=True, exist_ok=True)

    model = _build_model(
        args.model,
        pretrain_weights=str(args.pretrain_weights) if args.pretrain_weights else None,
    )

    # Train. RF-DETR handles reading COCO split annotation files.
    model.train(
        dataset_dir=str(dataset_dir),
        epochs=args.epochs,
        batch_size=args.batch_size,
        grad_accum_steps=args.grad_accum_steps,
        lr=args.lr,
        output_dir=str(output_dir),
        resume=str(args.resume) if args.resume else None,
        early_stopping=args.early_stopping,
        tensorboard=args.tensorboard,
        wandb=args.wandb,
    )

    print(f"Training complete. Outputs in: {output_dir}")
    return 0
|
|
|
|
|
|
if __name__ == "__main__":
    # Run the CLI and hand main()'s return value to the interpreter as the
    # process exit status.
    exit_status = main()
    raise SystemExit(exit_status)
|