step working on ease of deployment for oak d
This commit is contained in:
@ -41,6 +41,8 @@ class AnnotationApp:
|
|||||||
def __init__(self, images_dir: Path | None = None, model_weights: Path | None = None):
|
def __init__(self, images_dir: Path | None = None, model_weights: Path | None = None):
|
||||||
self.images_dir = images_dir if images_dir else Path.cwd()
|
self.images_dir = images_dir if images_dir else Path.cwd()
|
||||||
self.current_model_path = model_weights
|
self.current_model_path = model_weights
|
||||||
|
self.current_model_type = None # Track model type: 'rf-detr', 'rt-detr', 'yolov6', 'yolox'
|
||||||
|
self.available_models = [] # List of discovered models for quick switching
|
||||||
self.image_paths = []
|
self.image_paths = []
|
||||||
self.current_idx = 0
|
self.current_idx = 0
|
||||||
self.annotations = {} # image_name -> list of boxes
|
self.annotations = {} # image_name -> list of boxes
|
||||||
@ -74,28 +76,257 @@ class AnnotationApp:
|
|||||||
|
|
||||||
return f"✓ Loaded {len(self.image_paths)} images from {images_dir}"
|
return f"✓ Loaded {len(self.image_paths)} images from {images_dir}"
|
||||||
|
|
||||||
def _load_model(self, weights_path: Path):
|
def find_best_weights(self, directory: Path) -> tuple[Path | None, str | None]:
|
||||||
"""Load YOLO/YOLOX model for auto-labeling (Ultralytics format)."""
|
"""Find the best weights file in a directory based on model type detection."""
|
||||||
|
if not directory.exists():
|
||||||
|
return None, None
|
||||||
|
|
||||||
|
# Check for RF-DETR weights (checkpoint_best_total.pth)
|
||||||
|
rf_detr_weights = directory / "checkpoint_best_total.pth"
|
||||||
|
if rf_detr_weights.exists():
|
||||||
|
return rf_detr_weights, "rf-detr"
|
||||||
|
|
||||||
|
# Check for Ultralytics weights (best.pt) in weights/ subdirectory
|
||||||
|
ultralytics_weights = directory / "weights" / "best.pt"
|
||||||
|
if ultralytics_weights.exists():
|
||||||
|
# Try to determine specific type from directory name or other clues
|
||||||
|
dir_name = directory.name.lower()
|
||||||
|
if "rtdetr" in dir_name:
|
||||||
|
return ultralytics_weights, "rt-detr"
|
||||||
|
elif "yolov6" in dir_name:
|
||||||
|
return ultralytics_weights, "yolov6"
|
||||||
|
elif "yolox" in dir_name:
|
||||||
|
return ultralytics_weights, "yolox"
|
||||||
|
else:
|
||||||
|
# Default to rt-detr for ultralytics models
|
||||||
|
return ultralytics_weights, "rt-detr"
|
||||||
|
|
||||||
|
# Check for Ultralytics weights in training/weights/ subdirectory (YOLOv6/YOLOX format)
|
||||||
|
training_weights = directory / "training" / "weights" / "best.pt"
|
||||||
|
if training_weights.exists():
|
||||||
|
dir_name = directory.name.lower()
|
||||||
|
if "rtdetr" in dir_name:
|
||||||
|
return training_weights, "rt-detr"
|
||||||
|
elif "yolov6" in dir_name:
|
||||||
|
return training_weights, "yolov6"
|
||||||
|
elif "yolox" in dir_name:
|
||||||
|
return training_weights, "yolox"
|
||||||
|
else:
|
||||||
|
# Default to yolox for training/weights structure
|
||||||
|
return training_weights, "yolox"
|
||||||
|
|
||||||
|
# Check for direct best.pt in directory
|
||||||
|
direct_best = directory / "best.pt"
|
||||||
|
if direct_best.exists():
|
||||||
|
return direct_best, "rt-detr" # Default assumption
|
||||||
|
|
||||||
|
# Check for any .pth or .pt files as fallback
|
||||||
|
pth_files = list(directory.glob("*.pth")) + list(directory.glob("*.pt"))
|
||||||
|
if pth_files:
|
||||||
|
# Prefer files with "best" in name
|
||||||
|
best_files = [f for f in pth_files if "best" in f.name.lower()]
|
||||||
|
if best_files:
|
||||||
|
return best_files[0], self._guess_model_type_from_path(best_files[0])
|
||||||
|
else:
|
||||||
|
return pth_files[0], self._guess_model_type_from_path(pth_files[0])
|
||||||
|
|
||||||
|
return None, None
|
||||||
|
|
||||||
|
def _guess_model_type_from_path(self, path: Path) -> str:
|
||||||
|
"""Guess model type from file path."""
|
||||||
|
path_str = str(path).lower()
|
||||||
|
if "rf" in path_str or "checkpoint" in path_str:
|
||||||
|
return "rf-detr"
|
||||||
|
elif "rtdetr" in path_str:
|
||||||
|
return "rt-detr"
|
||||||
|
elif "yolov6" in path_str:
|
||||||
|
return "yolov6"
|
||||||
|
elif "yolox" in path_str:
|
||||||
|
return "yolox"
|
||||||
|
else:
|
||||||
|
return "rt-detr" # Default
|
||||||
|
|
||||||
|
def _load_model(self, weights_path: Path, model_type: str = None):
|
||||||
|
"""Load model for auto-labeling based on type."""
|
||||||
try:
|
try:
|
||||||
from ultralytics import YOLO
|
import torch
|
||||||
print(f"Loading model from {weights_path}...")
|
|
||||||
self.model = YOLO(str(weights_path))
|
if model_type is None:
|
||||||
self.current_model_path = weights_path
|
model_type = self._guess_model_type_from_path(weights_path)
|
||||||
print("✓ Model loaded")
|
|
||||||
return f"✓ Model loaded from {weights_path.name}"
|
print(f"Loading {model_type} model from {weights_path}...")
|
||||||
|
|
||||||
|
# Check for GPU availability
|
||||||
|
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
||||||
|
print(f"Using device: {device}")
|
||||||
|
|
||||||
|
if model_type == "rf-detr":
|
||||||
|
# RF-DETR uses custom loader - try different model sizes
|
||||||
|
from rfdetr import RFDETRBase, RFDETRMedium, RFDETRNano, RFDETRSmall
|
||||||
|
|
||||||
|
# Try to determine model size from checkpoint or use nano as default
|
||||||
|
checkpoint = torch.load(weights_path, map_location='cpu', weights_only=False)
|
||||||
|
if 'model' in checkpoint:
|
||||||
|
# Training checkpoint - check the model size from the state dict
|
||||||
|
state_dict = checkpoint['model']
|
||||||
|
# Look for clues about model size in the state dict keys
|
||||||
|
if any('backbone.0.encoder.encoder.embeddings.position_embeddings' in key for key in state_dict.keys()):
|
||||||
|
# Try different model sizes to find the right one
|
||||||
|
models_to_try = [
|
||||||
|
("nano", RFDETRNano),
|
||||||
|
("small", RFDETRSmall),
|
||||||
|
("medium", RFDETRMedium),
|
||||||
|
("base", RFDETRBase)
|
||||||
|
]
|
||||||
|
|
||||||
|
for size_name, model_class in models_to_try:
|
||||||
|
try:
|
||||||
|
self.model = model_class()
|
||||||
|
# Try loading with strict=False to handle mismatches
|
||||||
|
missing_keys, unexpected_keys = self.model.load_state_dict(state_dict, strict=False)
|
||||||
|
if len(missing_keys) < len(self.model.state_dict()): # Some keys matched
|
||||||
|
print(f"✓ Loaded RF-DETR {size_name} model (with {len(missing_keys)} missing keys)")
|
||||||
|
# Move to GPU if available
|
||||||
|
self.model = self.model.to(device)
|
||||||
|
break
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
error_msg = f"⚠ Could not load model: {e}"
|
continue
|
||||||
|
else:
|
||||||
|
raise Exception("Could not load checkpoint with any RF-DETR model size")
|
||||||
|
else:
|
||||||
|
# Direct weights file
|
||||||
|
self.model = RFDETRNano(pretrain_weights=str(weights_path))
|
||||||
|
self.model = self.model.to(device)
|
||||||
|
else:
|
||||||
|
# Direct weights file
|
||||||
|
self.model = RFDETRNano(pretrain_weights=str(weights_path))
|
||||||
|
self.model = self.model.to(device)
|
||||||
|
else:
|
||||||
|
# RT-DETR, YOLOv6, YOLOX all use Ultralytics
|
||||||
|
if model_type == "rt-detr":
|
||||||
|
from ultralytics import RTDETR
|
||||||
|
self.model = RTDETR(str(weights_path))
|
||||||
|
else:
|
||||||
|
from ultralytics import YOLO
|
||||||
|
self.model = YOLO(str(weights_path))
|
||||||
|
|
||||||
|
# Ultralytics models should automatically use GPU if available
|
||||||
|
# but let's ensure they're on the right device
|
||||||
|
if hasattr(self.model, 'to'):
|
||||||
|
self.model = self.model.to(device)
|
||||||
|
|
||||||
|
self.current_model_path = weights_path
|
||||||
|
self.current_model_type = model_type
|
||||||
|
|
||||||
|
# Add to available models for quick switching
|
||||||
|
model_display = f"Custom: {weights_path.name} ({model_type.upper()})"
|
||||||
|
existing_model = next((m for m in self.available_models if m['path'] == weights_path), None)
|
||||||
|
if not existing_model:
|
||||||
|
self.available_models.append({
|
||||||
|
"path": weights_path,
|
||||||
|
"type": model_type,
|
||||||
|
"dir": weights_path.parent.name,
|
||||||
|
"display": model_display
|
||||||
|
})
|
||||||
|
|
||||||
|
print("✓ Model loaded")
|
||||||
|
return f"✓ {model_type.upper()} model loaded from {weights_path.name}"
|
||||||
|
except Exception as e:
|
||||||
|
error_msg = f"⚠ Could not load {model_type or 'model'}: {e}"
|
||||||
print(error_msg)
|
print(error_msg)
|
||||||
self.model = None
|
self.model = None
|
||||||
|
self.current_model_type = None
|
||||||
return error_msg
|
return error_msg
|
||||||
|
|
||||||
def load_new_model(self, weights_path: str) -> str:
|
def load_new_model(self, weights_path: str, model_type: str = "Auto-detect") -> str:
    """Load a model from a weights file chosen in the GUI.

    Args:
        weights_path: Path to the weights file as typed by the user.
        model_type: Value of the model-type dropdown.  ``"Auto-detect"``
            lets the loader infer the type from the path; any other value
            is forwarded unchanged as the internal type name.

    Returns:
        Status message from the loader, or a "File not found" message when
        *weights_path* does not exist.
    """
    path = Path(weights_path)
    if not path.exists():
        return f"❌ File not found: {weights_path}"

    # The dropdown values already are the internal type names; only the
    # "Auto-detect" sentinel needs translating (to None).
    resolved_type = None if model_type == "Auto-detect" else model_type
    return self._load_model(path, resolved_type)
|
||||||
|
|
||||||
|
def load_model_from_directory(self, directory_path: str) -> str:
    """Find the best weights inside a directory and load them.

    Args:
        directory_path: Directory to search, as entered by the user.

    Returns:
        The loader's status message, or an error message when the path is
        missing, is not a directory, or holds no weights.
    """
    target = Path(directory_path)

    # Validate the path first; both failures share the same message shape.
    problem = None
    if not target.exists():
        problem = "Directory not found"
    elif not target.is_dir():
        problem = "Not a directory"
    if problem is not None:
        return f"❌ {problem}: {directory_path}"

    best_weights, detected_type = self.find_best_weights(target)
    if best_weights is not None:
        return self._load_model(best_weights, detected_type)
    return f"❌ No model weights found in {directory_path}"
|
||||||
|
|
||||||
|
def scan_for_models(self, return_info: bool = True) -> str:
    """Scan the ``runs/`` directory (relative to the CWD) for trained models.

    Every immediate sub-directory of ``runs/`` is probed with
    ``find_best_weights``.  The result replaces ``self.available_models``;
    each entry is a dict with the keys ``path``, ``type``, ``dir`` and
    ``display``.

    Args:
        return_info: When False, only refresh ``self.available_models`` and
            return an empty string.

    Returns:
        A human-readable listing of the models found, an error message when
        none were found, or ``""`` when *return_info* is False.
    """
    runs_dir = Path("runs")
    available_models = []

    # is_dir() rather than exists(): a plain file named "runs" must not make
    # iterdir() raise.  Sorting keeps numbering and dropdown order stable.
    if runs_dir.is_dir():
        for subdir in sorted(p for p in runs_dir.iterdir() if p.is_dir()):
            weights_path, model_type = self.find_best_weights(subdir)
            if weights_path:
                available_models.append({
                    "path": weights_path,
                    "type": model_type,
                    "dir": subdir.name,
                    "display": f"{subdir.name} ({model_type.upper()})",
                })

    # Store available models for quick access
    self.available_models = available_models

    if not return_info:
        return ""

    if not available_models:
        return "❌ No trained models found in 'runs/' directory"

    lines = ["📂 Available Models:"]
    lines.extend(
        f"{i}. {model['dir']} → {model['path'].name} ({model['type'].upper()})"
        for i, model in enumerate(available_models, 1)
    )
    lines.append("\n💡 Use the Model Selector dropdown above to quickly switch models")
    return "\n".join(lines)
|
||||||
|
|
||||||
|
def get_available_models_list(self) -> list:
    """Return the display names of the known models for a dropdown.

    Triggers a scan when no models are cached yet.  When there are still no
    models afterwards, a single placeholder entry is returned instead.
    """
    if not self.available_models:
        # Populates self.available_models as a side effect.
        self.scan_for_models(return_info=False)

    display_names = [entry['display'] for entry in self.available_models]
    if display_names:
        return display_names
    return ["No models found - click '🔍 Scan for Models'"]
|
||||||
|
|
||||||
|
def load_model_by_index(self, model_display: str) -> str:
    """Load the model whose display name equals *model_display*.

    Despite the method name, the lookup key is the entry's ``display``
    string in ``self.available_models``, not a numeric index.

    Returns:
        The loader's status message, or an error message when no models are
        known or none matches *model_display*.
    """
    known_models = getattr(self, 'available_models', None)
    if not known_models:
        return "❌ No models available. Click '🔍 Scan for Models' first."

    selected = next(
        (entry for entry in known_models if entry['display'] == model_display),
        None,
    )
    if selected is None:
        return f"❌ Model '{model_display}' not found"
    return self._load_model(selected['path'], selected['type'])
|
||||||
|
|
||||||
def load_new_images_dir(self, images_dir: str) -> tuple[Image.Image | None, str, str]:
|
def load_new_images_dir(self, images_dir: str) -> tuple[Image.Image | None, str, str]:
|
||||||
"""Load a new images directory from the GUI."""
|
"""Load a new images directory from the GUI."""
|
||||||
@ -122,7 +353,8 @@ class AnnotationApp:
|
|||||||
def get_current_model_info(self) -> str:
|
def get_current_model_info(self) -> str:
|
||||||
"""Get info about currently loaded model."""
|
"""Get info about currently loaded model."""
|
||||||
if self.model and self.current_model_path:
|
if self.model and self.current_model_path:
|
||||||
return f"📦 Loaded: {self.current_model_path}"
|
type_info = f" ({self.current_model_type.upper()})" if self.current_model_type else ""
|
||||||
|
return f"📦 Loaded: {self.current_model_path}{type_info}"
|
||||||
elif self.model:
|
elif self.model:
|
||||||
return "📦 Model loaded (pretrained)"
|
return "📦 Model loaded (pretrained)"
|
||||||
else:
|
else:
|
||||||
@ -160,52 +392,67 @@ class AnnotationApp:
|
|||||||
return img_draw
|
return img_draw
|
||||||
|
|
||||||
def auto_label_current(self, threshold: float = 0.5) -> tuple[Image.Image, str, str]:
|
def auto_label_current(self, threshold: float = 0.5) -> tuple[Image.Image, str, str]:
|
||||||
"""Auto-label current image with model."""
|
"""Auto-label current image using loaded model."""
|
||||||
if not self.model:
|
if not self.model:
|
||||||
img, filename = self.get_current_image()
|
return None, "", "❌ No model loaded"
|
||||||
info = f"⚠ No model loaded | Image {self.current_idx + 1}/{len(self.image_paths)}: {filename}"
|
|
||||||
return img, "", info
|
|
||||||
|
|
||||||
img, filename = self.get_current_image()
|
img, filename = self.get_current_image()
|
||||||
if not img:
|
if not img:
|
||||||
return None, "", "No images"
|
return None, "", "No images"
|
||||||
|
|
||||||
# Run inference with Ultralytics YOLO
|
try:
|
||||||
results = self.model.predict(img, conf=threshold, verbose=False)
|
# Run inference based on model type
|
||||||
|
if self.current_model_type == "rf-detr":
|
||||||
# Convert to our format
|
# RF-DETR custom prediction
|
||||||
|
detections = self.model.predict(img, threshold=threshold)
|
||||||
boxes = []
|
boxes = []
|
||||||
if len(results) > 0:
|
for i in range(len(detections)):
|
||||||
result = results[0] # First image result
|
xyxy = detections.xyxy[i]
|
||||||
if result.boxes is not None and len(result.boxes) > 0:
|
conf = float(detections.confidence[i]) if detections.confidence is not None else 1.0
|
||||||
for box in result.boxes:
|
x1, y1, x2, y2 = xyxy
|
||||||
xyxy = box.xyxy[0].cpu().numpy().tolist() # [x1, y1, x2, y2]
|
|
||||||
conf = float(box.conf[0].cpu().numpy())
|
|
||||||
cls = int(box.cls[0].cpu().numpy())
|
|
||||||
|
|
||||||
# Get class name if available
|
|
||||||
label = result.names.get(cls, f"class_{cls}") if hasattr(result, 'names') else f"class_{cls}"
|
|
||||||
|
|
||||||
boxes.append({
|
boxes.append({
|
||||||
"bbox": xyxy,
|
"bbox": [float(x1), float(y1), float(x2), float(y2)],
|
||||||
"label": label,
|
"label": "knot",
|
||||||
|
"confidence": conf,
|
||||||
|
"source": "auto"
|
||||||
|
})
|
||||||
|
else:
|
||||||
|
# Ultralytics models (RT-DETR, YOLOv6, YOLOX)
|
||||||
|
results = self.model.predict(
|
||||||
|
source=img,
|
||||||
|
conf=threshold,
|
||||||
|
save=False,
|
||||||
|
verbose=False
|
||||||
|
)
|
||||||
|
|
||||||
|
boxes = []
|
||||||
|
for result in results:
|
||||||
|
for box in result.boxes:
|
||||||
|
x1, y1, x2, y2 = box.xyxy[0].tolist()
|
||||||
|
conf = float(box.conf[0])
|
||||||
|
boxes.append({
|
||||||
|
"bbox": [x1, y1, x2, y2],
|
||||||
|
"label": "knot",
|
||||||
"confidence": conf,
|
"confidence": conf,
|
||||||
"source": "auto"
|
"source": "auto"
|
||||||
})
|
})
|
||||||
|
|
||||||
# Save
|
# Add to existing annotations
|
||||||
self.annotations[filename] = boxes
|
if filename not in self.annotations:
|
||||||
|
self.annotations[filename] = []
|
||||||
|
self.annotations[filename].extend(boxes)
|
||||||
self._save_annotations()
|
self._save_annotations()
|
||||||
|
|
||||||
# Draw boxes on image
|
# Redraw
|
||||||
img_with_boxes = self.draw_boxes_on_image(img, boxes)
|
img_with_boxes = self.draw_boxes_on_image(img, self.annotations[filename])
|
||||||
|
boxes_text = self._format_boxes_text(self.annotations[filename])
|
||||||
# Info with image index
|
info = f"🤖 Auto-labeled: {len(boxes)} detections | Image {self.current_idx + 1}/{len(self.image_paths)}: {filename}"
|
||||||
info = f"✓ Auto-labeled: {len(boxes)} boxes detected | Image {self.current_idx + 1}/{len(self.image_paths)}: {filename}"
|
|
||||||
boxes_text = self._format_boxes_text(boxes)
|
|
||||||
|
|
||||||
return img_with_boxes, boxes_text, info
|
return img_with_boxes, boxes_text, info
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
return img, self._format_boxes_text(self.annotations.get(filename, [])), f"❌ Auto-label failed: {e}"
|
||||||
|
|
||||||
def _format_boxes_text(self, boxes: list[dict]) -> str:
|
def _format_boxes_text(self, boxes: list[dict]) -> str:
|
||||||
"""Format boxes for display."""
|
"""Format boxes for display."""
|
||||||
if not boxes:
|
if not boxes:
|
||||||
@ -298,6 +545,68 @@ class AnnotationApp:
|
|||||||
|
|
||||||
return img, boxes_text, info
|
return img, boxes_text, info
|
||||||
|
|
||||||
|
def auto_label_current(self, threshold: float = 0.5) -> tuple[Image.Image, str, str]:
    """Run the loaded model on the current image and store its detections.

    New detections are appended to the annotations already stored for the
    image (they do not replace them) and the annotations are saved.  Every
    detection is stored with the label ``"knot"`` and ``source="auto"``.

    Args:
        threshold: Minimum confidence passed to the model's ``predict``.

    Returns:
        ``(image, boxes_text, info)``.  *image* is None when no model is
        loaded or there is no current image.  When inference fails, the
        undrawn image and the previously stored boxes are returned together
        with an error message.
    """
    if not self.model:
        return None, "", "❌ No model loaded"

    img, filename = self.get_current_image()
    if not img:
        return None, "", "No images"

    try:
        boxes = []
        if self.current_model_type == "rf-detr":
            # RF-DETR: predict() returns one detections object exposing
            # parallel `xyxy` and `confidence` sequences.
            detections = self.model.predict(img, threshold=threshold)
            confidences = detections.confidence
            for i in range(len(detections)):
                x1, y1, x2, y2 = detections.xyxy[i]
                conf = float(confidences[i]) if confidences is not None else 1.0
                boxes.append({
                    "bbox": [float(x1), float(y1), float(x2), float(y2)],
                    "label": "knot",
                    "confidence": conf,
                    "source": "auto",
                })
        else:
            # Ultralytics models (RT-DETR, YOLOv6, YOLOX)
            results = self.model.predict(
                source=img,
                conf=threshold,
                save=False,
                verbose=False,
            )
            for result in results:
                # A result may carry no boxes at all; treat that as "no
                # detections" instead of failing the whole auto-label call.
                if result.boxes is None:
                    continue
                for box in result.boxes:
                    x1, y1, x2, y2 = box.xyxy[0].tolist()
                    boxes.append({
                        "bbox": [x1, y1, x2, y2],
                        "label": "knot",
                        "confidence": float(box.conf[0]),
                        "source": "auto",
                    })

        # Add to existing annotations
        self.annotations.setdefault(filename, []).extend(boxes)
        self._save_annotations()

        # Redraw with everything stored for this image, not just the new boxes.
        stored = self.annotations[filename]
        img_with_boxes = self.draw_boxes_on_image(img, stored)
        boxes_text = self._format_boxes_text(stored)
        info = f"🤖 Auto-labeled: {len(boxes)} detections | Image {self.current_idx + 1}/{len(self.image_paths)}: {filename}"

        return img_with_boxes, boxes_text, info

    except Exception as e:
        # GUI boundary: report the failure instead of raising.
        return img, self._format_boxes_text(self.annotations.get(filename, [])), f"❌ Auto-label failed: {e}"
|
||||||
|
|
||||||
def _save_annotations(self):
|
def _save_annotations(self):
|
||||||
"""Save annotations to JSON file."""
|
"""Save annotations to JSON file."""
|
||||||
with self.ann_file.open("w") as f:
|
with self.ann_file.open("w") as f:
|
||||||
@ -439,7 +748,24 @@ class AnnotationApp:
|
|||||||
# Build training command based on framework
|
# Build training command based on framework
|
||||||
venv_python = Path(__file__).parent / ".venv/bin/python"
|
venv_python = Path(__file__).parent / ".venv/bin/python"
|
||||||
|
|
||||||
if framework == "RT-DETR":
|
if framework == "RF-DETR":
|
||||||
|
train_script = Path(__file__).parent / "train_rfdetr.py"
|
||||||
|
# Map sizes: nano->nano, small->small, medium->medium, base->base
|
||||||
|
size_map = {"nano": "nano", "small": "small", "medium": "medium", "base": "base"}
|
||||||
|
model_arg = size_map.get(model_size, "medium")
|
||||||
|
|
||||||
|
cmd = [
|
||||||
|
str(venv_python),
|
||||||
|
str(train_script),
|
||||||
|
"--dataset-dir", str(dataset_path),
|
||||||
|
"--output-dir", str(output_path),
|
||||||
|
"--model", model_arg,
|
||||||
|
"--epochs", str(epochs),
|
||||||
|
"--batch-size", str(batch_size),
|
||||||
|
"--grad-accum-steps", "2", # Default grad accum
|
||||||
|
"--lr", str(lr)
|
||||||
|
]
|
||||||
|
elif framework == "RT-DETR":
|
||||||
train_script = Path(__file__).parent / "train_rtdetr.py"
|
train_script = Path(__file__).parent / "train_rtdetr.py"
|
||||||
# Map sizes: nano->r18, small->r34, medium->r50, base->l
|
# Map sizes: nano->r18, small->r34, medium->r50, base->l
|
||||||
size_map = {"nano": "rtdetr-r18", "small": "rtdetr-r34", "medium": "rtdetr-r50", "base": "rtdetr-l"}
|
size_map = {"nano": "rtdetr-r18", "small": "rtdetr-r34", "medium": "rtdetr-r50", "base": "rtdetr-l"}
|
||||||
@ -509,9 +835,23 @@ class AnnotationApp:
|
|||||||
if self.training_process.returncode == 0:
|
if self.training_process.returncode == 0:
|
||||||
self.training_status = "✅ Training completed successfully!"
|
self.training_status = "✅ Training completed successfully!"
|
||||||
# Reload model with new weights
|
# Reload model with new weights
|
||||||
|
if framework == "RF-DETR":
|
||||||
|
# RF-DETR uses checkpoint_best_total.pth
|
||||||
best_weights = output_path / "checkpoint_best_total.pth"
|
best_weights = output_path / "checkpoint_best_total.pth"
|
||||||
|
model_type = "rf-detr"
|
||||||
|
elif framework == "RT-DETR":
|
||||||
|
# RT-DETR uses best.pt in weights/ subdirectory (Ultralytics)
|
||||||
|
best_weights = output_path / "weights" / "best.pt"
|
||||||
|
model_type = "rt-detr"
|
||||||
|
elif framework == "YOLOv6":
|
||||||
|
best_weights = output_path / "weights" / "best.pt"
|
||||||
|
model_type = "yolov6"
|
||||||
|
elif framework == "YOLOX":
|
||||||
|
best_weights = output_path / "weights" / "best.pt"
|
||||||
|
model_type = "yolox"
|
||||||
|
|
||||||
if best_weights.exists():
|
if best_weights.exists():
|
||||||
self._load_model(best_weights)
|
self._load_model(best_weights, model_type)
|
||||||
else:
|
else:
|
||||||
self.training_status = f"❌ Training failed (exit code {self.training_process.returncode})"
|
self.training_status = f"❌ Training failed (exit code {self.training_process.returncode})"
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@ -534,6 +874,88 @@ class AnnotationApp:
|
|||||||
return "✓ Training process terminated"
|
return "✓ Training process terminated"
|
||||||
return "⚠️ No training in progress"
|
return "⚠️ No training in progress"
|
||||||
|
|
||||||
|
def export_for_oak_d(self, model_path: str, output_dir: str = "oak_d_export", img_size: int = 640):
    """Export trained weights to ONNX (plus OpenVINO when possible) for OAK-D.

    RF-DETR weights are exported through the ``rfdetr`` package; all other
    types go through Ultralytics (ONNX with opset 11, then OpenVINO FP32).
    The exported files are moved into *output_dir*.

    Args:
        model_path: Path to the trained weights file.
        output_dir: Directory that receives the exported files; created if
            missing.
        img_size: Model input size passed to the Ultralytics exporters.

    Returns:
        A human-readable status string.  Failures are reported in the
        returned string rather than raised.
    """
    import shutil

    try:
        weights_path = Path(model_path)
        output_path = Path(output_dir)

        if not weights_path.exists():
            return "❌ Model weights not found"

        output_path.mkdir(parents=True, exist_ok=True)

        # Prefer the type recorded when this file was scanned or loaded (it
        # may have been chosen explicitly by the user); guess from the path
        # only when the file is unknown.
        known_models = getattr(self, "available_models", None) or []
        model_type = next(
            (
                entry["type"]
                for entry in known_models
                if entry.get("type") and Path(entry["path"]) == weights_path
            ),
            None,
        ) or self._guess_model_type_from_path(weights_path)

        print(f"Exporting {model_type} model for OAK-D...")

        if model_type == "rf-detr":
            from rfdetr import RFDETRBase

            # NOTE(review): always instantiates the "base" size, regardless of
            # the size the checkpoint was trained with — confirm this is
            # intended for nano/small/medium checkpoints.
            model = RFDETRBase(pretrain_weights=str(weights_path))
            model.export()  # expected to write output/model.onnx — TODO confirm

            onnx_source = Path("output/model.onnx")
            if not onnx_source.exists():
                return "❌ ONNX export failed"

            # shutil.move also works when source and destination live on
            # different filesystems, where Path.rename raises.
            shutil.move(str(onnx_source), str(output_path / "rf_detr_model.onnx"))
            return f"✓ RF-DETR exported for OAK-D!\n📁 Output: {output_path}\n🔗 Next: Convert ONNX to blob using blobconverter.luxonis.com"

        # Ultralytics models (RT-DETR, YOLOv6, YOLOX)
        if model_type == "rt-detr":
            from ultralytics import RTDETR
            model = RTDETR(str(weights_path))
        else:
            from ultralytics import YOLO
            model = YOLO(str(weights_path))

        onnx_path = model.export(
            format="onnx",
            imgsz=img_size,
            simplify=True,
            opset=11,  # OAK-compatible opset
        )
        if not onnx_path or not Path(onnx_path).exists():
            # Same outcome as the RF-DETR branch: do not claim success
            # without an ONNX file.
            return "❌ ONNX export failed"
        shutil.move(str(onnx_path), str(output_path / f"{model_type}_model.onnx"))

        # OpenVINO is optional: on any failure fall back to the ONNX result.
        try:
            openvino_path = model.export(
                format="openvino",
                imgsz=img_size,
                half=False,  # Use FP32 for better compatibility
            )

            openvino_dir = Path(openvino_path)
            if openvino_dir.exists():
                for file in openvino_dir.glob("*"):
                    if file.is_file():
                        shutil.move(str(file), str(output_path / file.name))
                try:
                    openvino_dir.rmdir()  # Remove the directory if now empty
                except OSError:
                    # Something other than plain files is left behind; keep
                    # the directory instead of misreporting the export.
                    pass

            return f"✓ {model_type.upper()} exported for OAK-D!\n📁 Output: {output_path}\n🔗 Next: Convert .xml/.bin to blob using blobconverter.luxonis.com"

        except Exception as e:
            return f"✓ {model_type.upper()} exported to ONNX!\n📁 Output: {output_path}\n🔗 Next: Convert ONNX to blob using blobconverter.luxonis.com\n⚠️ OpenVINO not available: {str(e)}"

    except Exception as e:
        # GUI boundary: surface the error text to the user.
        return f"❌ Export failed: {str(e)}"
|
||||||
|
|
||||||
|
|
||||||
def create_ui(app: AnnotationApp) -> gr.Blocks:
|
def create_ui(app: AnnotationApp) -> gr.Blocks:
|
||||||
"""Create Gradio UI."""
|
"""Create Gradio UI."""
|
||||||
@ -545,7 +967,7 @@ def create_ui(app: AnnotationApp) -> gr.Blocks:
|
|||||||
|
|
||||||
- Manually annotate images or use **Auto-Label** with your trained model
|
- Manually annotate images or use **Auto-Label** with your trained model
|
||||||
- Export and prepare dataset for training
|
- Export and prepare dataset for training
|
||||||
- Train **RT-DETR, YOLOv6, or YOLOX** (all free for commercial use!)
|
- Train **RF-DETR, RT-DETR, YOLOv6, or YOLOX** (all free for commercial use!)
|
||||||
- Optimized for OAK-D camera deployment
|
- Optimized for OAK-D camera deployment
|
||||||
- Use trained model to auto-label more images
|
- Use trained model to auto-label more images
|
||||||
""")
|
""")
|
||||||
@ -563,12 +985,31 @@ def create_ui(app: AnnotationApp) -> gr.Blocks:
|
|||||||
dir_info = gr.Textbox(label="Current Directory", value=app.get_current_dir_info(), interactive=False)
|
dir_info = gr.Textbox(label="Current Directory", value=app.get_current_dir_info(), interactive=False)
|
||||||
|
|
||||||
with gr.Column():
|
with gr.Column():
|
||||||
|
# Quick Model Selector
|
||||||
|
model_selector = gr.Dropdown(
|
||||||
|
choices=app.get_available_models_list(),
|
||||||
|
value=None,
|
||||||
|
label="🚀 Quick Model Switcher",
|
||||||
|
info="Select from available trained models (refresh with scan)",
|
||||||
|
allow_custom_value=True
|
||||||
|
)
|
||||||
|
quick_load_btn = gr.Button("⚡ Load Selected Model", variant="primary")
|
||||||
|
|
||||||
|
# Manual Model Loading
|
||||||
model_weights_input = gr.Textbox(
|
model_weights_input = gr.Textbox(
|
||||||
label="Model Weights Path",
|
label="Model Weights Path",
|
||||||
value=str(app.current_model_path) if app.current_model_path else "",
|
value=str(app.current_model_path) if app.current_model_path else "",
|
||||||
placeholder="runs/training/checkpoint_best_total.pth"
|
placeholder="runs/training/checkpoint_best_total.pth"
|
||||||
)
|
)
|
||||||
|
model_type_dropdown = gr.Dropdown(
|
||||||
|
choices=["Auto-detect", "rf-detr", "rt-detr", "yolov6", "yolox"],
|
||||||
|
value="Auto-detect",
|
||||||
|
label="Model Type",
|
||||||
|
info="Auto-detect will try to determine from file path"
|
||||||
|
)
|
||||||
|
with gr.Row():
|
||||||
load_model_btn = gr.Button("🤖 Load Model Weights")
|
load_model_btn = gr.Button("🤖 Load Model Weights")
|
||||||
|
scan_models_btn = gr.Button("🔍 Scan for Models")
|
||||||
model_info = gr.Textbox(label="Current Model", value=app.get_current_model_info(), interactive=False)
|
model_info = gr.Textbox(label="Current Model", value=app.get_current_model_info(), interactive=False)
|
||||||
|
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
@ -613,7 +1054,8 @@ def create_ui(app: AnnotationApp) -> gr.Blocks:
|
|||||||
### Train Object Detection Model
|
### Train Object Detection Model
|
||||||
|
|
||||||
**Choose your framework:**
|
**Choose your framework:**
|
||||||
- **RT-DETR** (Apache 2.0): Modern transformer, great accuracy
|
- **RF-DETR** (MIT): Custom transformer, high accuracy
|
||||||
|
- **RT-DETR** (Apache 2.0): Ultralytics transformer, great accuracy
|
||||||
- **YOLOv6** (MIT): Fast, proven on OAK cameras
|
- **YOLOv6** (MIT): Fast, proven on OAK cameras
|
||||||
- **YOLOX** (MIT): Similar to YOLOv6, slight differences
|
- **YOLOX** (MIT): Similar to YOLOv6, slight differences
|
||||||
|
|
||||||
@ -641,7 +1083,7 @@ def create_ui(app: AnnotationApp) -> gr.Blocks:
|
|||||||
with gr.Column():
|
with gr.Column():
|
||||||
gr.Markdown("### Training Configuration")
|
gr.Markdown("### Training Configuration")
|
||||||
model_framework = gr.Dropdown(
|
model_framework = gr.Dropdown(
|
||||||
choices=["RT-DETR", "YOLOv6", "YOLOX"],
|
choices=["RF-DETR", "RT-DETR", "YOLOv6", "YOLOX"],
|
||||||
value="RT-DETR",
|
value="RT-DETR",
|
||||||
label="Model Framework",
|
label="Model Framework",
|
||||||
info="All MIT/Apache 2.0 licensed - free for commercial use. Optimized for OAK cameras."
|
info="All MIT/Apache 2.0 licensed - free for commercial use. Optimized for OAK cameras."
|
||||||
@ -679,6 +1121,79 @@ def create_ui(app: AnnotationApp) -> gr.Blocks:
|
|||||||
Check the training log file for detailed progress.
|
Check the training log file for detailed progress.
|
||||||
""")
|
""")
|
||||||
|
|
||||||
|
# OAK-D Deployment tab
|
||||||
|
with gr.Tab("🚀 OAK-D Deployment"):
|
||||||
|
gr.Markdown("""
|
||||||
|
### Deploy Trained Model to OAK-D Camera
|
||||||
|
|
||||||
|
Convert your trained model to work with the **OAK-D 4 Pro** camera for real-time edge inference.
|
||||||
|
|
||||||
|
**Supported Models**: RF-DETR, RT-DETR, YOLOv6, YOLOX
|
||||||
|
|
||||||
|
**Process**:
|
||||||
|
1. Select a trained model from your runs/ directory
|
||||||
|
2. Export to ONNX and OpenVINO formats
|
||||||
|
3. Convert OpenVINO model to blob for OAK-D
|
||||||
|
4. Deploy blob to your OAK-D camera
|
||||||
|
""")
|
||||||
|
|
||||||
|
with gr.Row():
|
||||||
|
with gr.Column():
|
||||||
|
oak_model_selector = gr.Dropdown(
|
||||||
|
choices=app.get_available_models_list(),
|
||||||
|
value=None,
|
||||||
|
label="Select Trained Model",
|
||||||
|
info="Choose a model from your training runs",
|
||||||
|
allow_custom_value=True
|
||||||
|
)
|
||||||
|
oak_output_dir = gr.Textbox(
|
||||||
|
label="Output Directory",
|
||||||
|
value="oak_d_deployment",
|
||||||
|
placeholder="oak_d_deployment"
|
||||||
|
)
|
||||||
|
oak_img_size = gr.Dropdown(
|
||||||
|
choices=[320, 416, 512, 640, 800, 1024],
|
||||||
|
value=640,
|
||||||
|
label="Image Size",
|
||||||
|
info="Input size for the model (should match training)"
|
||||||
|
)
|
||||||
|
|
||||||
|
with gr.Row():
|
||||||
|
oak_scan_btn = gr.Button("🔍 Scan for Models")
|
||||||
|
oak_export_btn = gr.Button("🚀 Export for OAK-D", variant="primary")
|
||||||
|
|
||||||
|
oak_status = gr.Textbox(
|
||||||
|
label="Export Status",
|
||||||
|
value="Ready to export",
|
||||||
|
lines=4
|
||||||
|
)
|
||||||
|
|
||||||
|
with gr.Column():
|
||||||
|
gr.Markdown("""
|
||||||
|
### 📋 Deployment Instructions
|
||||||
|
|
||||||
|
**After Export:**
|
||||||
|
1. **Test OpenVINO Model** (optional):
|
||||||
|
```bash
|
||||||
|
python -c "from openvino.runtime import Core; core = Core(); model = core.read_model('model.xml'); print('✓ Model loaded')"
|
||||||
|
```
|
||||||
|
|
||||||
|
2. **Convert to Blob**:
|
||||||
|
- Go to: https://blobconverter.luxonis.com/
|
||||||
|
- Upload your `.xml` and `.bin` files
|
||||||
|
- Select OAK-D device
|
||||||
|
- Download the `.blob` file
|
||||||
|
|
||||||
|
3. **Deploy to OAK-D**:
|
||||||
|
- Use DepthAI Python API
|
||||||
|
- Or use OAK-D examples with your blob
|
||||||
|
|
||||||
|
### 💡 Tips
|
||||||
|
- Use **FP32** for best accuracy (default)
|
||||||
|
- **Nano models** work best on edge devices
|
||||||
|
- Test inference speed vs accuracy trade-off
|
||||||
|
""")
|
||||||
|
|
||||||
# Event handlers
|
# Event handlers
|
||||||
def on_load():
|
def on_load():
|
||||||
return app.load_image("current")
|
return app.load_image("current")
|
||||||
@ -695,8 +1210,28 @@ def create_ui(app: AnnotationApp) -> gr.Blocks:
|
|||||||
|
|
||||||
load_model_btn.click(
|
load_model_btn.click(
|
||||||
app.load_new_model,
|
app.load_new_model,
|
||||||
inputs=[model_weights_input],
|
inputs=[model_weights_input, model_type_dropdown],
|
||||||
outputs=[model_info]
|
outputs=[model_info]
|
||||||
|
).then(
|
||||||
|
app.get_available_models_list,
|
||||||
|
outputs=[model_selector]
|
||||||
|
)
|
||||||
|
|
||||||
|
scan_models_btn.click(
|
||||||
|
app.scan_for_models,
|
||||||
|
outputs=[model_info]
|
||||||
|
).then(
|
||||||
|
app.get_available_models_list,
|
||||||
|
outputs=[model_selector]
|
||||||
|
)
|
||||||
|
|
||||||
|
quick_load_btn.click(
|
||||||
|
app.load_model_by_index,
|
||||||
|
inputs=[model_selector],
|
||||||
|
outputs=[model_info]
|
||||||
|
).then(
|
||||||
|
app.get_available_models_list,
|
||||||
|
outputs=[model_selector]
|
||||||
)
|
)
|
||||||
|
|
||||||
prev_btn.click(
|
prev_btn.click(
|
||||||
@ -760,6 +1295,21 @@ def create_ui(app: AnnotationApp) -> gr.Blocks:
|
|||||||
outputs=[training_status]
|
outputs=[training_status]
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# OAK-D Deployment handlers
|
||||||
|
oak_scan_btn.click(
|
||||||
|
app.scan_for_models,
|
||||||
|
outputs=[oak_status]
|
||||||
|
).then(
|
||||||
|
app.get_available_models_list,
|
||||||
|
outputs=[oak_model_selector]
|
||||||
|
)
|
||||||
|
|
||||||
|
oak_export_btn.click(
|
||||||
|
app.export_for_oak_d,
|
||||||
|
inputs=[oak_model_selector, oak_output_dir, oak_img_size],
|
||||||
|
outputs=[oak_status]
|
||||||
|
)
|
||||||
|
|
||||||
# Load first image on start
|
# Load first image on start
|
||||||
demo.load(on_load, outputs=[image_display, boxes_text, info_text])
|
demo.load(on_load, outputs=[image_display, boxes_text, info_text])
|
||||||
|
|
||||||
@ -780,7 +1330,6 @@ def main():
|
|||||||
default=Path(DEFAULT_MODEL_WEIGHTS) if DEFAULT_MODEL_WEIGHTS else None,
|
default=Path(DEFAULT_MODEL_WEIGHTS) if DEFAULT_MODEL_WEIGHTS else None,
|
||||||
help="Default trained model for auto-labeling (can be changed in GUI)"
|
help="Default trained model for auto-labeling (can be changed in GUI)"
|
||||||
)
|
)
|
||||||
parser.add_argument("--port", type=int, default=DEFAULT_PORT, help="Port for web interface")
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
# Validate paths if provided
|
# Validate paths if provided
|
||||||
@ -797,6 +1346,9 @@ def main():
|
|||||||
# Create app
|
# Create app
|
||||||
app = AnnotationApp(args.images_dir, args.model_weights)
|
app = AnnotationApp(args.images_dir, args.model_weights)
|
||||||
|
|
||||||
|
# Scan for available models on startup
|
||||||
|
app.scan_for_models(return_info=False)
|
||||||
|
|
||||||
# Create and launch UI
|
# Create and launch UI
|
||||||
demo = create_ui(app)
|
demo = create_ui(app)
|
||||||
|
|
||||||
@ -815,7 +1367,7 @@ def main():
|
|||||||
|
|
||||||
demo.launch(
|
demo.launch(
|
||||||
server_name="0.0.0.0",
|
server_name="0.0.0.0",
|
||||||
server_port=args.port,
|
server_port=7860,
|
||||||
share=False
|
share=False
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@ -5,10 +5,10 @@ DEFAULT_IMAGES_DIR = "IMAGE/" # Directory containing wood defect images
|
|||||||
DEFAULT_MODEL_WEIGHTS = "runs/yolox_training/training/weights/best.pt" # Trained YOLOX model
|
DEFAULT_MODEL_WEIGHTS = "runs/yolox_training/training/weights/best.pt" # Trained YOLOX model
|
||||||
|
|
||||||
# Training defaults
|
# Training defaults
|
||||||
DEFAULT_TRAIN_EPOCHS = 20
|
DEFAULT_TRAIN_EPOCHS = 1
|
||||||
DEFAULT_BATCH_SIZE = 4
|
DEFAULT_BATCH_SIZE = 4
|
||||||
DEFAULT_LEARNING_RATE = 1e-4
|
DEFAULT_LEARNING_RATE = 1e-4
|
||||||
DEFAULT_MODEL_SIZE = "small" # nano, small, medium, base
|
DEFAULT_MODEL_SIZE = "nano" # nano, small, medium, base
|
||||||
|
|
||||||
# Dataset split ratios
|
# Dataset split ratios
|
||||||
DEFAULT_TRAIN_SPLIT = 0.8
|
DEFAULT_TRAIN_SPLIT = 0.8
|
||||||
|
|||||||
51
fix_class_ids.py
Normal file
51
fix_class_ids.py
Normal file
@ -0,0 +1,51 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""
|
||||||
|
Fix COCO annotation class IDs to be 0-based instead of 1-based.
|
||||||
|
RF-DETR expects classes to start from 0 (background), but our annotations use 1-10.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import json
|
||||||
|
import sys
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
def fix_class_ids(coco_data):
    """Shift COCO category IDs from 1-based to 0-based, in place.

    Every ``id`` in ``coco_data['categories']`` and every ``category_id`` in
    ``coco_data['annotations']`` is decremented by one.

    The shift is skipped when any ID is already below 1, i.e. the data is
    already 0-based. Without this guard, running the script a second time on
    the same files would push IDs to -1 and silently corrupt the dataset.

    Args:
        coco_data: Parsed COCO annotation dict containing ``'categories'``
            and ``'annotations'`` lists.

    Returns:
        The same ``coco_data`` object (mutated in place when a shift applies).

    Raises:
        KeyError: If ``'categories'`` or ``'annotations'`` is missing. Both
            keys are read before anything is modified, so a failure never
            leaves the data half-shifted.
    """
    categories = coco_data['categories']
    annotations = coco_data['annotations']

    # Collect every ID that would be shifted so the decision is all-or-nothing.
    seen_ids = [cat['id'] for cat in categories]
    seen_ids.extend(ann['category_id'] for ann in annotations)

    # Nothing to shift, or an ID is already < 1: treat the data as 0-based.
    if not seen_ids or min(seen_ids) < 1:
        return coco_data

    # Fix categories: e.g. 1-10 becomes 0-9
    for cat in categories:
        cat['id'] -= 1

    # Fix annotations with the same shift so they keep pointing at the
    # matching category.
    for ann in annotations:
        ann['category_id'] -= 1

    return coco_data
|
||||||
|
|
||||||
|
def main():
    """Rewrite each split's COCO annotation file with 0-based class IDs.

    Looks for ``dataset_coco/<split>/_annotations.coco.json`` for the train,
    valid and test splits. Missing files are reported and skipped; existing
    files are loaded, passed through ``fix_class_ids`` and written back to
    the same path.
    """
    root = Path('dataset_coco')

    for split_name in ('train', 'valid', 'test'):
        ann_file = root / split_name / '_annotations.coco.json'

        if ann_file.exists():
            print(f"Processing {ann_file}...")

            # Read the current annotations.
            with ann_file.open('r') as src:
                data = json.load(src)

            # Shift the class IDs.
            fixed_data = fix_class_ids(data)

            # Overwrite the original file with the result.
            with ann_file.open('w') as dst:
                json.dump(fixed_data, dst, indent=2)

            n_annotations = len(data['annotations'])
            n_categories = len(data['categories'])
            print(f"✅ Fixed {n_annotations} annotations and {n_categories} categories")
        else:
            print(f"Warning: {ann_file} not found, skipping")

    print("🎉 All COCO annotations fixed for 0-based class indexing!")
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
main()
|
||||||
50
fix_coco_annotations.py
Normal file
50
fix_coco_annotations.py
Normal file
@ -0,0 +1,50 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""
|
||||||
|
Fix COCO annotations for RF-DETR compatibility by adding missing 'supercategory' field.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import json
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
|
||||||
|
def fix_coco_annotations(coco_file: Path):
    """Ensure every category in a COCO annotation file has a 'supercategory'.

    The file is read, categories lacking the field receive the default
    value ``'wood_defect'`` (existing values are left untouched), and the
    result is written back to the same path with 2-space indentation.

    Args:
        coco_file: Path to the COCO annotation JSON file to update in place.
    """
    print(f"Fixing {coco_file}...")

    # Read the existing annotations.
    with coco_file.open('r') as src:
        data = json.load(src)

    # setdefault only writes when the key is absent, preserving existing values.
    categories = data['categories']
    for entry in categories:
        entry.setdefault('supercategory', 'wood_defect')

    # Write the updated annotations back over the original file.
    with coco_file.open('w') as dst:
        json.dump(data, dst, indent=2)

    print(f"✅ Fixed {len(data['categories'])} categories in {coco_file}")
|
||||||
|
|
||||||
|
|
||||||
|
def main():
    """Run the supercategory fix over the train, valid and test splits.

    Each split's annotation file under ``dataset_coco/`` is passed to
    ``fix_coco_annotations`` when it exists; a missing file is reported
    and skipped.
    """
    print("Fixing COCO annotations for RF-DETR compatibility...")

    targets = (
        Path("dataset_coco/train/_annotations.coco.json"),
        Path("dataset_coco/valid/_annotations.coco.json"),
        Path("dataset_coco/test/_annotations.coco.json"),
    )

    for coco_file in targets:
        # Guard clause: report the missing file and move on to the next split.
        if not coco_file.exists():
            print(f"⚠️ {coco_file} not found")
            continue
        fix_coco_annotations(coco_file)

    print("\n🎉 All COCO annotations fixed for RF-DETR compatibility!")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
Reference in New Issue
Block a user