
Comments (4)

RS-v620 commented on July 22, 2024

model.train() accepts a single_cls argument that treats every class in the dataset as one class, so the model learns detection only rather than detection plus classification. Set it to True and the model will only detect the trash in images.
Keep in mind that this removes the distinction between any classes you have in your dataset.

You can also limit the scope of detection at inference time, with your existing model, by using the classes argument of model.predict(). Print model.names on a separate line to get the class names, look up the index of trash (say 0), then pass it to your inference code as model.predict(image, classes=[0]).
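The index lookup described above can be sketched in plain Python. This is a minimal illustration, not part of the Ultralytics API: class_indices is a hypothetical helper, and the names dictionary below is made-up sample data standing in for whatever model.names returns for your model.

```python
def class_indices(names, wanted):
    """Return the sorted indices in an {index: name} mapping whose name is in `wanted`."""
    wanted = set(wanted)
    found = sorted(idx for idx, name in names.items() if name in wanted)
    missing = wanted - {names[idx] for idx in found}
    if missing:
        # Fail loudly rather than silently filtering on nothing
        raise KeyError(f"classes not in model: {sorted(missing)}")
    return found


# Illustrative mapping; with a real model this would be `model.names`.
names = {0: "trash", 1: "bottle", 2: "can"}
indices = class_indices(names, ["trash"])
print(indices)
```

The resulting list would then be passed as model.predict(image, classes=indices), which avoids hard-coding an index that may differ between models.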

The other arguments of model.train() are listed in the Ultralytics documentation.


glenn-jocher commented on July 22, 2024

@pauldeveaux hello,

Thank you for your prompt response and for providing additional details. It's great to hear that using single_cls=True has improved your results. However, I understand the confusion regarding the model.names output.

The issue you're encountering is likely due to the model still retaining the class names from the pre-trained weights. When you fine-tune a model with a custom dataset, especially with a single class, it's essential to ensure that the model's class names are updated accordingly.

Here are a few steps to address this:

  1. Ensure Data Configuration: Double-check your data.yaml file to ensure it correctly specifies the single class. It looks like your data.yaml is correctly set up, but just to confirm:

    names:
    - trash
    nc: 1
    train: /content/Trash-Detector-2/train
    val: /content/Trash-Detector-2/valid
    test: /content/Trash-Detector-2/test/images
  2. Load Model with Custom Data: When loading the model for training, ensure it reads the custom data configuration. This can be done by specifying the data argument correctly:

    results = model.train(
        data='/content/Trash-Detector-2/data.yaml',
        epochs=150,
        imgsz=640,
        batch=-1,
        workers=16,
        patience=50,
        single_cls=True,
        resume=os.path.exists(weights_path)
    )
  3. Update Class Names: After training, you can manually update the model.names to reflect your custom class:

    model.names = {0: 'trash'}
  4. Check Model Names After Training: Ensure that the model's class names are updated after training:

    print(model.names)  # Should output {0: 'trash'}
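The check in step 1 can also be done programmatically. The sketch below is only an illustration: check_dataset_config is a hypothetical helper, and the dictionary mirrors the data.yaml shown above as if it had already been loaded (for example with a YAML parser, which is assumed here and not shown).

```python
def check_dataset_config(cfg):
    """Return a list of problems found in a dataset config dict (empty if none)."""
    problems = []
    names = cfg.get("names")
    # `names` may be a list or an {index: name} dict; normalise to a list
    if isinstance(names, dict):
        names = [names[k] for k in sorted(names)]
    if not names:
        problems.append("no class names defined")
        names = []
    nc = cfg.get("nc")
    if nc is not None and nc != len(names):
        problems.append(f"nc={nc} but {len(names)} name(s) listed")
    for split in ("train", "val"):
        if not cfg.get(split):
            problems.append(f"missing '{split}' path")
    return problems


# Same values as the data.yaml above, written as a Python dict
config = {
    "names": ["trash"],
    "nc": 1,
    "train": "/content/Trash-Detector-2/train",
    "val": "/content/Trash-Detector-2/valid",
    "test": "/content/Trash-Detector-2/test/images",
}
print(check_dataset_config(config))
```

An empty list means the class count and the class names agree and both required splits are present. It does not check that the paths exist on disk.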

Here is a revised version of your training script with these considerations:

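import os

import torch
import wandb
from ultralytics import YOLO
from wandb.integration.ultralytics import add_wandb_callback

# Assumption: RESTART is a user-defined flag; set it to True to ignore saved weights
RESTART = False

# Check if GPU is available
if torch.cuda.is_available():
    print("GPU is available")
else:
    print("GPU is not available")

weights_path = "/content/runs/detect/train2/weights/last.pt"
# Resume only when a checkpoint exists and a restart was not requested
resume_training = os.path.exists(weights_path) and not RESTART
if resume_training:
    print("Resuming training from saved weights...")
    model = YOLO(weights_path)
else:
    print("Starting new training...")
    model = YOLO("/content/yolov8n.pt")

def save_on_drive(trainer):
    # Debug: list the attributes available on the trainer
    print(trainer.__dir__())
    # Save a copy of the model to Google Drive at the end of each epoch
    epoch = trainer.epoch
    save_dir = "/content/drive/MyDrive/Model"
    os.makedirs(save_dir, exist_ok=True)
    save_path = os.path.join(save_dir, f'yolo_model_{epoch+1}.pt')
    model.save(save_path)

# wandb logs
add_wandb_callback(model, enable_model_checkpointing=True)
model.add_callback("on_train_epoch_end", save_on_drive)

results = model.train(
    data='/content/Trash-Detector-2/data.yaml',
    epochs=150,
    imgsz=640,
    batch=-1,
    workers=16,
    patience=50,
    single_cls=True,
    resume=resume_training  # keep resume consistent with the weights loaded above
)

# Update model names to reflect the custom class.
# If this assignment raises AttributeError (names may be a read-only property
# in your ultralytics version), try setting model.model.names instead.
model.names = {0: 'trash'}

# Validate the model
model.val()

# Finish the wandb run
wandb.finish()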

By following these steps, you should be able to ensure that the model correctly recognizes and outputs the 'trash' class. If you continue to experience issues, please let us know, and we can further investigate.

Thank you for your patience and dedication to improving your model. Happy training! 🚀


github-actions commented on July 22, 2024

👋 Hello @pauldeveaux, thank you for your interest in Ultralytics YOLOv8 🚀! We recommend a visit to the Docs for new users where you can find many Python and CLI usage examples and where many of the most common questions may already be answered.

If this is a 🐛 Bug Report, please provide a minimum reproducible example to help us debug it.

If this is a custom training ❓ Question, please provide as much information as possible, including dataset image examples and training logs, and verify you are following our Tips for Best Training Results.

Join the vibrant Ultralytics Discord 🎧 community for real-time conversations and collaborations. This platform offers a perfect space to inquire, showcase your work, and connect with fellow Ultralytics users.

Install

Pip install the ultralytics package including all requirements in a Python>=3.8 environment with PyTorch>=1.8.

pip install ultralytics

Environments

YOLOv8 may be run in any of the following up-to-date verified environments (with all dependencies including CUDA/CUDNN, Python and PyTorch preinstalled):

Status

Ultralytics CI

If this badge is green, all Ultralytics CI tests are currently passing. CI tests verify correct operation of all YOLOv8 Modes and Tasks on macOS, Windows, and Ubuntu every 24 hours and on every commit.


pauldeveaux commented on July 22, 2024

Hello,

Thank you for your response. I did as you suggested and added single_cls=True as an argument to model.train(). I got better results at the end of training.

results = model.train(
    data=os.path.join(dataset.location, "/content/Trash-Detector-2/data.yaml"),
    epochs=150,
    imgsz=640,
    batch=-1,
    workers=16,
    patience=50,
    single_cls=True,
    resume=os.path.exists(weights_path)
)

However, when I print model.names, the "trash" class does not appear in the output:

{0: 'person', 1: 'bicycle', 2: 'car', 3: 'motorcycle', 4: 'airplane', 5: 'bus', 6: 'train', 7: 'truck', 8: 'boat', 9: 'traffic light', 10: 'fire hydrant', 11: 'stop sign', 12: 'parking meter', 13: 'bench', 14: 'bird', 15: 'cat', 16: 'dog', 17: 'horse', 18: 'sheep', 19: 'cow', 20: 'elephant', 21: 'bear', 22: 'zebra', 23: 'giraffe', 24: 'backpack', 25: 'umbrella', 26: 'handbag', 27: 'tie', 28: 'suitcase', 29: 'frisbee', 30: 'skis', 31: 'snowboard', 32: 'sports ball', 33: 'kite', 34: 'baseball bat', 35: 'baseball glove', 36: 'skateboard', 37: 'surfboard', 38: 'tennis racket', 39: 'bottle', 40: 'wine glass', 41: 'cup', 42: 'fork', 43: 'knife', 44: 'spoon', 45: 'bowl', 46: 'banana', 47: 'apple', 48: 'sandwich', 49: 'orange', 50: 'broccoli', 51: 'carrot', 52: 'hot dog', 53: 'pizza', 54: 'donut', 55: 'cake', 56: 'chair', 57: 'couch', 58: 'potted plant', 59: 'bed', 60: 'dining table', 61: 'toilet', 62: 'tv', 63: 'laptop', 64: 'mouse', 65: 'remote', 66: 'keyboard', 67: 'cell phone', 68: 'microwave', 69: 'oven', 70: 'toaster', 71: 'sink', 72: 'refrigerator', 73: 'book', 74: 'clock', 75: 'vase', 76: 'scissors', 77: 'teddy bear', 78: 'hair drier', 79: 'toothbrush'}

Could you please help me understand why the class is missing and how I can ensure it is included in the model?

Thank you for your assistance!
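A side note on the data= argument in the training call above: when a later component passed to os.path.join is an absolute path, the earlier components are discarded, so dataset.location has no effect there. The sketch below illustrates this with posixpath (the implementation behind os.path on Linux and Colab) so that it behaves the same on any platform. The value of dataset_location is a hypothetical stand-in, since the real dataset.location is not shown in the thread.

```python
import posixpath

# Hypothetical stand-in for dataset.location
dataset_location = "/content/datasets/Trash-Detector-2"

# An absolute second component discards everything before it
joined = posixpath.join(dataset_location, "/content/Trash-Detector-2/data.yaml")
print(joined)  # /content/Trash-Detector-2/data.yaml

# A relative component is appended as expected
relative = posixpath.join(dataset_location, "data.yaml")
print(relative)  # /content/datasets/Trash-Detector-2/data.yaml
```

The call in the post therefore resolves to the absolute data.yaml path, which may or may not be the file inside dataset.location.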
