There is an argument, single_cls, in the model.train() function that makes the model focus on detection instead of classification. You can set it to True, which will make the model detect only the trash in your images. Keep in mind that this removes the distinction between any classes you have in your dataset.
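One way to picture what single_cls=True does is that every labelled object is treated as class 0 while its box is kept. The sketch below illustrates that idea on YOLO-format label lines ("class x_center y_center width height"). It is an illustration only, not Ultralytics' internal code, and collapse_to_single_class is a hypothetical helper name.

```python
def collapse_to_single_class(label_lines):
    """Rewrite YOLO-format label lines so every object uses class id 0.

    Illustration of the single-class idea only (hypothetical helper);
    the box coordinates are left untouched.
    """
    out = []
    for line in label_lines:
        parts = line.split()
        if not parts:
            continue  # skip blank lines
        parts[0] = "0"  # overwrite the class id
        out.append(" ".join(parts))
    return out


print(collapse_to_single_class(["3 0.5 0.5 0.2 0.2", "7 0.1 0.2 0.3 0.4"]))
```

Because every id becomes 0, a "bottle" box and a "bag" box become indistinguishable, which is exactly the loss of class distinction mentioned above.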
Also, you can limit the scope of detection when running inference with your existing model by passing the classes argument to model.predict(). Get the class names with model.names on a separate line, find the index of trash (say 0), then pass it to your inference code as model.predict(image, classes=[0]).
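Looking up the index by name, rather than hard-coding 0, can be sketched as below. This assumes model.names is a {index: name} dictionary, as in the output printed later in this thread; class_ids_for is a hypothetical helper, not part of the ultralytics package.

```python
def class_ids_for(names, wanted):
    """Return the ids in a {index: name} mapping whose name is in `wanted`.

    Hypothetical helper; raises KeyError if a requested name is absent,
    so a model that does not know the class fails loudly.
    """
    missing = set(wanted) - set(names.values())
    if missing:
        raise KeyError(f"classes not in model.names: {sorted(missing)}")
    return [idx for idx, label in names.items() if label in wanted]


print(class_ids_for({0: "person", 1: "trash"}, {"trash"}))  # [1]
```

With a trained model this would be used as model.predict(image, classes=class_ids_for(model.names, {"trash"})). On the 80-entry COCO dictionary shown later in this thread it would raise KeyError, which flags that the model in memory does not know a "trash" class.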
You can look up the other arguments of model.train() in the Ultralytics documentation.
from ultralytics.
@pauldeveaux hello,
Thank you for your prompt response and for providing additional details. It's great to hear that using single_cls=True has improved your results. However, I understand the confusion regarding the model.names output.
The issue you're encountering is likely due to the model still retaining the class names from the pre-trained weights. When you fine-tune a model with a custom dataset, especially with a single class, it's essential to ensure that the model's class names are updated accordingly.
Here are a few steps to address this:
- Ensure Data Configuration: Double-check your data.yaml file to make sure it correctly specifies the single class. Your data.yaml looks correctly set up, but just to confirm, it should read:

    names:
      - trash
    nc: 1
    train: /content/Trash-Detector-2/train
    val: /content/Trash-Detector-2/valid
    test: /content/Trash-Detector-2/test/images
- Load Model with Custom Data: When loading the model for training, make sure it reads the custom data configuration. This is done by specifying the data argument correctly:

    results = model.train(
        data='/content/Trash-Detector-2/data.yaml',
        epochs=150,
        imgsz=640,
        batch=-1,
        workers=16,
        patience=50,
        single_cls=True,
        resume=os.path.exists(weights_path)
    )
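- Update Class Names: After training, you can manually update the class names to reflect your custom class. model.names is a read-only property on the YOLO object, so assign to the underlying network instead:

    model.model.names = {0: 'trash'}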
- Check Model Names After Training: Confirm that the model's class names are updated after training:

    print(model.names)  # Should output {0: 'trash'}
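To check model.names against what the dataset declares, the names entry of data.yaml can be turned into the same {index: name} form. A minimal sketch, assuming the YAML has already been parsed into a dict (for example with PyYAML's yaml.safe_load, which is not in the standard library); expected_names is a hypothetical helper.

```python
def expected_names(cfg):
    """Build the {index: name} mapping a model trained on `cfg` should report.

    Hypothetical helper; `cfg` is data.yaml already parsed into a dict.
    """
    names = cfg["names"]
    if isinstance(names, list):
        names = dict(enumerate(names))  # YAML list form: index = position
    nc = cfg.get("nc", len(names))
    if nc != len(names):
        raise ValueError(f"nc={nc} but {len(names)} names are listed")
    return names


data_cfg = {
    "names": ["trash"],
    "nc": 1,
    "train": "/content/Trash-Detector-2/train",
    "val": "/content/Trash-Detector-2/valid",
}
print(expected_names(data_cfg))  # {0: 'trash'}
```

After training, expected_names(data_cfg) == model.names should hold; the 80-class COCO dictionary printed later in this thread would fail that comparison immediately.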
Here is a revised version of your training script with these considerations:
import os

import torch
import wandb
from ultralytics import YOLO
from wandb.integration.ultralytics import add_wandb_callback

RESTART = False  # set to True to ignore saved weights and start a new run

# Check if GPU is available
if torch.cuda.is_available():
    print("GPU is available")
else:
    print("GPU is not available")

weights_path = "/content/runs/detect/train2/weights/last.pt"
# Resume only if saved weights exist and a restart was not requested
resume = os.path.exists(weights_path) and not RESTART

# Start or resume model
if resume:
    print("Resuming training from saved weights...")
    model = YOLO(weights_path)
else:
    print("Starting new training...")
    model = YOLO("/content/yolov8n.pt")


def save_on_drive(trainer):
    # Save a copy of the model to Google Drive at the end of every epoch
    epoch = trainer.epoch
    save_dir = "/content/drive/MyDrive/Model"
    os.makedirs(save_dir, exist_ok=True)
    save_path = os.path.join(save_dir, f"yolo_model_{epoch + 1}.pt")
    model.save(save_path)


# wandb logs
add_wandb_callback(model, enable_model_checkpointing=True)
model.add_callback("on_train_epoch_end", save_on_drive)

results = model.train(
    data="/content/Trash-Detector-2/data.yaml",
    epochs=150,
    imgsz=640,
    batch=-1,
    workers=16,
    patience=50,
    single_cls=True,
    resume=resume,
)

# Update model names to reflect the custom class
model.model.names = {0: "trash"}

# Validate the model
model.val()

# Finish the wandb run
wandb.finish()
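The start-or-resume decision can be kept in one place so that the weights loaded and the resume flag passed to model.train() cannot disagree (for example, RESTART set to True while last.pt still exists). A minimal sketch; choose_start is a hypothetical helper, and the paths are the ones used in the script.

```python
import os


def choose_start(weights_path, base_weights, restart=False):
    """Return (weights_to_load, resume_flag).

    Hypothetical helper: resume only when a previous checkpoint exists
    and a restart was not requested; otherwise start from base_weights.
    """
    resume = os.path.exists(weights_path) and not restart
    return (weights_path if resume else base_weights), resume


start_weights, resume = choose_start(
    "/content/runs/detect/train2/weights/last.pt", "/content/yolov8n.pt"
)
print(start_weights, resume)
```

The result would then feed model = YOLO(start_weights) and model.train(..., resume=resume), so both always come from the same decision.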
By following these steps, you should be able to ensure that the model correctly recognizes and outputs the 'trash' class. If you continue to experience issues, please let us know, and we can further investigate.
Thank you for your patience and dedication to improving your model. Happy training!
Hello @pauldeveaux, thank you for your interest in Ultralytics YOLOv8! We recommend a visit to the Docs for new users where you can find many Python and CLI usage examples and where many of the most common questions may already be answered.
If this is a Bug Report, please provide a minimum reproducible example to help us debug it.
If this is a custom training Question, please provide as much information as possible, including dataset image examples and training logs, and verify you are following our Tips for Best Training Results.
Join the vibrant Ultralytics Discord community for real-time conversations and collaborations. This platform offers a perfect space to inquire, showcase your work, and connect with fellow Ultralytics users.
Install
Pip install the ultralytics package including all requirements in a Python>=3.8 environment with PyTorch>=1.8.
pip install ultralytics
Environments
YOLOv8 may be run in any of the following up-to-date verified environments (with all dependencies including CUDA/CUDNN, Python and PyTorch preinstalled):
- Notebooks with free GPU:
- Google Cloud Deep Learning VM. See GCP Quickstart Guide
- Amazon Deep Learning AMI. See AWS Quickstart Guide
- Docker Image. See Docker Quickstart Guide
Status
If this badge is green, all Ultralytics CI tests are currently passing. CI tests verify correct operation of all YOLOv8 Modes and Tasks on macOS, Windows, and Ubuntu every 24 hours and on every commit.
Hello,
Thank you for your response. I did as you said (added single_cls=True as an argument to model.train()) and got better results at the end of training.
results = model.train(
data=os.path.join(dataset.location, "/content/Trash-Detector-2/data.yaml"),
epochs=150,
imgsz=640,
batch=-1,
workers=16,
patience=50,
single_cls=True,
resume=os.path.exists(weights_path)
)
However, when I print model.names, the "trash" class does not appear in the output:
{0: 'person', 1: 'bicycle', 2: 'car', 3: 'motorcycle', 4: 'airplane', 5: 'bus', 6: 'train', 7: 'truck', 8: 'boat', 9: 'traffic light', 10: 'fire hydrant', 11: 'stop sign', 12: 'parking meter', 13: 'bench', 14: 'bird', 15: 'cat', 16: 'dog', 17: 'horse', 18: 'sheep', 19: 'cow', 20: 'elephant', 21: 'bear', 22: 'zebra', 23: 'giraffe', 24: 'backpack', 25: 'umbrella', 26: 'handbag', 27: 'tie', 28: 'suitcase', 29: 'frisbee', 30: 'skis', 31: 'snowboard', 32: 'sports ball', 33: 'kite', 34: 'baseball bat', 35: 'baseball glove', 36: 'skateboard', 37: 'surfboard', 38: 'tennis racket', 39: 'bottle', 40: 'wine glass', 41: 'cup', 42: 'fork', 43: 'knife', 44: 'spoon', 45: 'bowl', 46: 'banana', 47: 'apple', 48: 'sandwich', 49: 'orange', 50: 'broccoli', 51: 'carrot', 52: 'hot dog', 53: 'pizza', 54: 'donut', 55: 'cake', 56: 'chair', 57: 'couch', 58: 'potted plant', 59: 'bed', 60: 'dining table', 61: 'toilet', 62: 'tv', 63: 'laptop', 64: 'mouse', 65: 'remote', 66: 'keyboard', 67: 'cell phone', 68: 'microwave', 69: 'oven', 70: 'toaster', 71: 'sink', 72: 'refrigerator', 73: 'book', 74: 'clock', 75: 'vase', 76: 'scissors', 77: 'teddy bear', 78: 'hair drier', 79: 'toothbrush'}
Could you please help me understand why the class is missing and how I can ensure it is included in the model?
Thank you for your assistance!