Basketball-Action-Recognition

Spatio-Temporal Classification of 🏀 Basketball Actions using 3D-CNN Models trained on the SpaceJam Dataset.

[Live example: LeBron James shooting over DeAndre Jordan]

Motivation

Using the SpaceJam Basketball Action Dataset repo, I aim to build a model that takes a video of a basketball game and classifies the action of each player tracked with a bounding box. The program has two essential parts: the R(2+1)D model (any 3D CNN architecture could be used instead) and player tracking. The network was trained with PyTorch on an Nvidia RTX 3060 Ti GPU.

This is a demo video from the SpaceJam repo.

[Video: SpaceJam demo]

Action/Video Classification

A baseline R(2+1)D CNN from torchvision.models, pretrained on the Kinetics-400 dataset, is fine-tuned on the SpaceJam dataset. Any 3D CNN architecture could be used, but R(2+1)D was chosen as a good balance between parameter count and overall model performance. The R(2+1)D paper (Tran et al., 2018) also showed that factorizing 3D convolutional filters into separate spatial and temporal components, combined with residual learning, yields significant gains in accuracy. Training is done in train.py.

Dataset

As mentioned above, the SpaceJam Basketball Action Dataset was used to train the R(2+1)D CNN for video/action classification of basketball actions. The repo contains two datasets of single-player basketball actions: clips (.mp4 files) and joints (.npy files). The two final annotated datasets contain about 32,560 examples. Custom dataloaders for the basketball dataset are defined in dataset.py.
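JSON object keys are always strings, so integer class ids have to be converted back after loading. The Issues section below shows augment_videos.py calling json.load(f, object_hook=keystoint); the hook here is only a guess at what such a helper does, and the label names used in the check are illustrative.

```python
import json


def keystoint(obj: dict) -> dict:
    """object_hook for json.load: turn digit-string keys ("0", "1", ...) into ints.

    Hypothetical reimplementation; the repo's own helper is not shown in this README.
    """
    return {int(k) if k.isdigit() else k: v for k, v in obj.items()}


def load_labels(path: str) -> dict:
    """Load a {class_id: name} mapping from a JSON file."""
    with open(path, encoding="utf-8") as f:
        # JSON requires double-quoted keys; single quotes raise
        # JSONDecodeError("Expecting property name enclosed in double quotes").
        return json.load(f, object_hook=keystoint)
```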

Augmentations

After reading the thesis Classificazione di Azioni Cestistiche mediante Tecniche di Deep Learning ("Classification of Basketball Actions using Deep Learning Techniques", written by Simone Francia), it was decided to augment the weakest classes, i.e. those with fewer than 2,000 examples. The augmented classes were Dribble, Ball in Hand, Pass, Block, Pick and Shoot. Augmentations were applied by running the script augment_videos.py and saved to a given output directory. Translation and rotation were the only augmentations applied. After augmentation, the dataset has 49,901 examples.
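augment_videos.py is not reproduced here. The NumPy-only sketch below (hypothetical function names) illustrates the two steps described above: picking the classes under the 2,000-example threshold, and applying one translation or rotation to every frame of a (T, H, W, C) clip. Using the same shift or angle for all frames, so that motion inside the clip stays coherent, is an assumption; the actual script may rely on OpenCV (for example cv2.warpAffine) and different parameter ranges.

```python
import numpy as np


def classes_to_augment(counts, threshold=2000):
    """Names of classes with fewer than `threshold` examples."""
    return sorted(name for name, n in counts.items() if n < threshold)


def _shift_slices(d, n):
    """(source, destination) slices along one axis for a shift of d pixels."""
    if d >= 0:
        return slice(0, n - d), slice(d, n)
    return slice(-d, n), slice(0, n + d)


def translate_clip(clip, dx, dy):
    """Shift every frame of a (T, H, W[, C]) clip by (dx, dy) pixels; vacated pixels are 0."""
    _, h, w = clip.shape[:3]
    if abs(dx) >= w or abs(dy) >= h:
        raise ValueError("shift must be smaller than the frame")
    (ys, yd), (xs, xd) = _shift_slices(dy, h), _shift_slices(dx, w)
    out = np.zeros_like(clip)
    out[:, yd, xd] = clip[:, ys, xs]
    return out


def rotate_clip(clip, angle_deg):
    """Rotate every frame about its centre (nearest neighbour); uncovered pixels are 0."""
    _, h, w = clip.shape[:3]
    theta = np.deg2rad(angle_deg)
    cy, cx = (h - 1) / 2.0, (w - 1) / 2.0
    yy, xx = np.mgrid[0:h, 0:w]
    # Inverse mapping: for each output pixel, find the source pixel it came from.
    src_x = np.rint(np.cos(theta) * (xx - cx) + np.sin(theta) * (yy - cy) + cx).astype(int)
    src_y = np.rint(-np.sin(theta) * (xx - cx) + np.cos(theta) * (yy - cy) + cy).astype(int)
    valid = (src_x >= 0) & (src_x < w) & (src_y >= 0) & (src_y < h)
    out = np.zeros_like(clip)
    out[:, yy[valid], xx[valid]] = clip[:, src_y[valid], src_x[valid]]
    return out
```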

Rotate

Translate

Training

Training is done in train.py. The model was trained for 25 epochs with a batch size of 8, using the classic 70/20/10 split: 70% of the data was used for training, 20% for validation, and the remaining 10% was held out for inference to test the final model. A learning rate of 0.0001 was found to work better than 0.001.
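A minimal sketch of a seeded 70/20/10 split over example indices (hypothetical helper; how train.py actually splits the data is not shown). Integer percentages keep the sizes exact: for the 49,901 augmented examples this gives 34,930 / 9,980 / 4,991.

```python
import random


def split_indices(n, train_pct=70, val_pct=20, seed=0):
    """Shuffle indices 0..n-1 and split them into train/val/test (test gets the remainder)."""
    idx = list(range(n))
    random.Random(seed).shuffle(idx)  # seeded so the split is reproducible
    n_train = n * train_pct // 100
    n_val = n * val_pct // 100
    return idx[:n_train], idx[n_train:n_train + n_val], idx[n_train + n_val:]
```

Because augmentation happens before this split, augmented copies of one clip can end up in different subsets. Splitting the ids of the original clips first, and adding augmented copies only to the training set, would address the first item under Future Additions.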

Checkpointing

Both the training history and the model checkpoints are saved after every epoch using checkpoints.py in the utils directory.
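A sketch of per-epoch checkpointing in PyTorch, assuming the file-name pattern seen in the Issues section (r2plus1d_multiclass_{epoch}_{lr}.pt); the dictionary keys and helper names are hypothetical. Recording metrics as plain Python floats (via .item()) keeps the history plottable: the train.py traceback in the Issues fails because CUDA tensors were passed to matplotlib.

```python
import os

import torch
from torch import nn


def record(history: dict, **metrics) -> None:
    """Append one value per metric, converting tensors (possibly on CUDA) to floats."""
    for name, value in metrics.items():
        if isinstance(value, torch.Tensor):
            value = value.detach().cpu().item()
        history.setdefault(name, []).append(float(value))


def save_checkpoint(model: nn.Module, optimizer: torch.optim.Optimizer, epoch: int,
                    lr: float, history: dict, out_dir: str) -> str:
    """Save model/optimizer state plus the metric history; returns the file path."""
    os.makedirs(out_dir, exist_ok=True)
    # Assumed naming scheme, e.g. r2plus1d_multiclass_19_0.0001.pt
    path = os.path.join(out_dir, f"r2plus1d_multiclass_{epoch}_{lr}.pt")
    torch.save({
        "epoch": epoch,
        "model_state_dict": model.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
        "history": history,
    }, path)
    return path
```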

Validation and Evaluation

The final model is an R(2+1)D CNN trained with the additional augmented examples. For evaluation on the test set, the model from epoch 19 was used, as it performed best in terms of validation F1-score and accuracy. The model performs significantly better than the 73% reported in the thesis Classificazione di Azioni Cestistiche mediante Tecniche di Deep Learning, achieving 85% for both validation and test accuracy. The confusion matrix was obtained with inference.py. Further analysis of predictions and errors is in the error_analysis.ipynb notebook.
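A plain-Python sketch of the metrics used for model selection. Which F1 averaging was used is not stated; macro-averaging (the unweighted mean of per-class F1) is assumed here, as is breaking ties between epochs by accuracy. Epochs are numbered from 1, matching checkpoint names such as r2plus1d_multiclass_19_0.0001.pt. All helper names are hypothetical.

```python
def accuracy(y_true, y_pred):
    """Fraction of predictions that match the true labels."""
    return sum(t == p for t, p in zip(y_true, y_pred)) / len(y_true)


def macro_f1(y_true, y_pred, num_classes):
    """Unweighted mean of per-class F1 = 2TP / (2TP + FP + FN); empty classes score 0."""
    f1s = []
    for c in range(num_classes):
        tp = sum(1 for t, p in zip(y_true, y_pred) if t == c and p == c)
        fp = sum(1 for t, p in zip(y_true, y_pred) if t != c and p == c)
        fn = sum(1 for t, p in zip(y_true, y_pred) if t == c and p != c)
        denom = 2 * tp + fp + fn
        f1s.append(2 * tp / denom if denom else 0.0)
    return sum(f1s) / num_classes


def best_epoch(val_f1, val_acc):
    """1-based epoch with the highest validation F1, ties broken by validation accuracy."""
    return max(range(len(val_f1)), key=lambda i: (val_f1[i], val_acc[i])) + 1
```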

Confusion Matrix

- 0: Block, 1: Pass, 2: Run, 3: Dribble, 4: Shoot, 5: Ball in Hand, 6: Defence, 7: Pick, 8: No Action, 9: Walk
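The same label ids can be used to build and read a confusion matrix. In this sketch (plain Python, hypothetical helpers), rows are true classes and columns are predictions, the same convention as scikit-learn's confusion_matrix; `most_confused` lists the largest off-diagonal cells, for example defenders predicted as Dribble, the error described in the Note below.

```python
LABELS = {0: "Block", 1: "Pass", 2: "Run", 3: "Dribble", 4: "Shoot",
          5: "Ball in Hand", 6: "Defence", 7: "Pick", 8: "No Action", 9: "Walk"}


def confusion_matrix(y_true, y_pred, num_classes=len(LABELS)):
    """cm[i][j] counts examples of true class i that were predicted as class j."""
    cm = [[0] * num_classes for _ in range(num_classes)]
    for t, p in zip(y_true, y_pred):
        cm[t][p] += 1
    return cm


def most_confused(cm, labels=LABELS, k=3):
    """The k largest off-diagonal cells as (true label, predicted label, count)."""
    n = len(cm)
    pairs = [(cm[i][j], i, j) for i in range(n) for j in range(n) if i != j and cm[i][j]]
    pairs.sort(reverse=True)
    return [(labels[i], labels[j], count) for count, i, j in pairs[:k]]
```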

Test on the Training and Validation Sets

[Figure: confusion matrix on the training and validation sets]

Test on the Remaining 10% of the Data

[Figure: confusion matrix on the test set]

Inference Examples - Error Analysis

[Example clips: for each of Shooting, Dribble, Pass, Defence, Pick, Run, Walk, Block and No Action, one correctly classified (True) and one misclassified (False) prediction]

Player Tracking

All player tracking is done in main.py. Players are tracked by manually selecting an ROI for each one and following it with OpenCV's CSRT tracker (TrackerCSRT_create()). In theory, any number of players can be tracked, but each additional player significantly increases compute time. In the example above, only two players were tracked: LeBron James (offence) and DeAndre Jordan (defence). A simple example of player tracking is available in Basketball-Player-Tracker.
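The first issue in the Issues section shows an empty box, (0, 0, 0, 0), reaching tracker initialisation, after which OpenCV's CSRT tracker fails inside cv::dft. Filtering such boxes out first is one way to avoid this (an assumption about the cause). This sketch keeps the tracker factory injectable so it can be exercised without a video; all names are hypothetical.

```python
from typing import Callable, Iterable, List, Optional, Sequence, Tuple

Box = Tuple[int, int, int, int]  # (x, y, w, h), the format returned by cv2.selectROI


def valid_boxes(boxes: Iterable[Sequence[int]]) -> List[Box]:
    """Drop empty selections such as (0, 0, 0, 0)."""
    return [tuple(int(v) for v in b) for b in boxes if int(b[2]) > 0 and int(b[3]) > 0]


def init_trackers(frame, boxes, create_tracker: Callable) -> list:
    """Create and initialise one tracker per non-empty box."""
    trackers = []
    for box in valid_boxes(boxes):
        tracker = create_tracker()
        tracker.init(frame, box)
        trackers.append((box, tracker))
    if not trackers:
        raise ValueError("no non-empty bounding box was selected")
    return trackers


def update_trackers(frame, trackers) -> List[Optional[Box]]:
    """New (x, y, w, h) for each tracker, or None where tracking failed on this frame."""
    results = []
    for _, tracker in trackers:
        ok, box = tracker.update(frame)
        results.append(tuple(int(v) for v in box) if ok else None)
    return results
```

With OpenCV, `create_tracker` would be `cv2.TrackerCSRT_create` (in some opencv-contrib builds it lives under `cv2.legacy`), and the boxes would come from `cv2.selectROI` or `cv2.selectROIs`.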

Output

After the bounding boxes are extracted with TrackerCSRT_create(), a cropped clip of 16 frames is used to classify each action. Successive 16-frame clips are taken according to vid_stride (set to 8 in the example video above), which is set in the cropWindows() function in main.py. Within each cropped window's time frame, the predicted action is displayed on top of the tracked player's bounding box.
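A sketch of the windowing step, reading vid_stride as the step between the starts of successive 16-frame windows (an interpretation; cropWindows() itself is not shown). Crops are clamped to the frame, so their sizes can differ from frame to frame; resizing each crop to the model's input size before stacking them into a tensor is left out. Names are hypothetical.

```python
import numpy as np

CLIP_LEN = 16   # frames per clip fed to the classifier
VID_STRIDE = 8  # value used for the example video


def window_starts(num_frames, clip_len=CLIP_LEN, stride=VID_STRIDE):
    """Start indices of every full clip_len-frame window, `stride` frames apart."""
    return list(range(0, num_frames - clip_len + 1, stride))


def crop_clip(frames, boxes):
    """Crop frame t of a (T, H, W, C) array to box t = (x, y, w, h), clamped to the frame."""
    height, width = frames.shape[1:3]
    crops = []
    for frame, (x, y, w, h) in zip(frames, boxes):
        x0, y0 = max(0, x), max(0, y)
        x1, y1 = min(width, x + w), min(height, y + h)
        crops.append(frame[y0:y1, x0:x1])
    return crops
```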

Future Additions

  • Keep augmented examples out of the validation set and use them only for training.
  • Use better player-tracking methods.
  • Restrict the box size to 176x128 (or a similar aspect ratio) so that the image does not need to be resized.
  • Fully automate player tracking, potentially using YOLO or another object detection model.
  • Experiment with hyperparameters such as learning rate, batch size, frozen layers, etc.
  • Try other 3D-CNN architectures or sequential models such as ConvLSTMs.
  • Improve the model to over 90% accuracy and F1-score.

Note:

  • The model does not perform well on off-ball actions. Often, the defender is classified as dribbling when they are not. This might be because the stances for dribbling and for playing defence are similar: in both, players generally bend their torsos forward to lower their centre of gravity.

Credits

Major thanks to Simone Francia for the basketball action dataset and paper on action classification with 3D-CNNs.

Contributors

hkair

Issues

error: (-215:Assertion failed) type == CV_32FC1 || type == CV_32FC2 || type == CV_64FC1 || type == CV_64FC2 in function 'cv::dft'

Dear author, after downloading yolov3.cfg and the weights and putting them in the main folder, and putting lebron_shoots.mp4 in the videos folder, I ran into the following problem in 'main.py':


4.5.1
Press q to quit selecting boxes and start tracking
Press any other key to select next object
113
Selected bounding boxes [(0, 0, 0, 0)]
Traceback (most recent call last):

File "D:\Basketball-Action-Recognition\main.py", line 547, in <module>
main()

File "D:\Basketball-Action-Recognition\main.py", line 494, in main
videoFrames, playerBoxes, Width, Height, colors = extractFrame(args.videoPath)

File "D:\Basketball-Action-Recognition\main.py", line 180, in extractFrame
trackers.add(createTrackerByName(args.tracker), frame, bbox)

error: OpenCV(4.5.1) C:\Users\appveyor\AppData\Local\Temp\1\pip-req-build-r2ue8w6k\opencv\modules\core\src\dxt.cpp:3506: error: (-215:Assertion failed) type == CV_32FC1 || type == CV_32FC2 || type == CV_64FC1 || type == CV_64FC2 in function 'cv::dft'

Problems I met & record of solutions

Below are some problems I met and tried to solve so that I could run the code.

Dear author,
I downloaded your repo, kept running into bugs, and at first didn't know how to start. Below is a record of my attempts:

  1. I downloaded yolov3.cfg & yolov3.weights into the main folder.
  2. I downloaded dataset.zip from https://github.com/simonefrancia/SpaceJam and unzipped it into the main folder.
  3. I ran main.py and selected a rectangle, and then it told me: [Errno 2] No such file or directory: 'model_checkpoints/r2plus1d_augmented-2/r2plus1d_multiclass_19_0.0001.pt'
  4. I thought maybe I should run train.py first to get the above 'r2plus1d_multiclass' checkpoint, but it told me: [Errno 2] No such file or directory: 'dataset/augmented_annotation_dict.json'
  5. I thought maybe I should run augment_videos.py first to get 'augmented_annotation_dict.json', but it told me:
    Traceback (most recent call last):

File "D:\Basketball-Action-Recognition\augment_videos.py", line 128, in <module>
augmentVideo(annotation_dict, labels_dict)

File "D:\Basketball-Action-Recognition\augment_videos.py", line 23, in augmentVideo
labels_dict = json.load(f, object_hook=keystoint)

File "D:\anaconda3\envs\pymarl\lib\json\__init__.py", line 296, in load
parse_constant=parse_constant, object_pairs_hook=object_pairs_hook, **kw)

File "D:\anaconda3\envs\pymarl\lib\json\__init__.py", line 361, in loads
return cls(**kw).decode(s)

File "D:\anaconda3\envs\pymarl\lib\json\decoder.py", line 337, in decode
_obj, end = self.raw_decode(s, idx=w(s, 0).end())

File "D:\anaconda3\envs\pymarl\lib\json\decoder.py", line 353, in raw_decode
obj, end = self.scan_once(s, idx)

JSONDecodeError: Expecting property name enclosed in double quotes

After the above:

  1. I modified dataset\labels_dict.json so that the keys are in double quotes, and then I got 'augmented_annotation_dict.json'.
  2. Then I ran train.py. It seemed it would train for about 1 hour; I'll update then.
    It seems to take 1 hour per epoch, and there are 25 epochs... I'll probably update tomorrow.
    After 1 epoch, I have 'model_checkpoints/r2plus1d_augmented-2/r2plus1d_multiclass_1_0.0001.pt', which differs from the 'model_checkpoints/r2plus1d_augmented-2/r2plus1d_multiclass_19_0.0001.pt' above; maybe I should wait until epoch 19.
    ★ Attention: you should reserve 341 MB × 25 ≈ 8.3 GB of space to save them.
  3. Now I have r2plus1d_multiclass_1_0.0001.pt~r2plus1d_multiclass_1_0.0024.pt, and it told me:
    Training complete in 1466m 17s
    Best val Acc: 0.854709
    Best Validation Loss: 0.45085233312969436 Epoch: 3
    Best Training Loss: 0.022664556910204244 Epoch: 23
    Traceback (most recent call last):
    File "D:\Basketball-Action-Recognition\train.py", line 343, in <module>
    plot_epoch

File "D:\Basketball-Action-Recognition\utils\checkpoints.py", line 88, in plot_curves
plt.plot(epochs, train_acc, label='train accuracy')

File "D:\anaconda3\envs\pymarl\lib\site-packages\matplotlib\pyplot.py", line 2759, in plot
**({"data": data} if data is not None else {}), **kwargs)

File "D:\anaconda3\envs\pymarl\lib\site-packages\matplotlib\axes\_axes.py", line 1632, in plot
lines = [*self._get_lines(*args, data=data, **kwargs)]

File "D:\anaconda3\envs\pymarl\lib\site-packages\matplotlib\axes\_base.py", line 312, in __call__
yield from self._plot_args(this, kwargs)

File "D:\anaconda3\envs\pymarl\lib\site-packages\matplotlib\axes\_base.py", line 488, in _plot_args
y = _check_1d(xy[1])

File "D:\anaconda3\envs\pymarl\lib\site-packages\matplotlib\cbook\__init__.py", line 1304, in _check_1d
return np.atleast_1d(x)

File "<__array_function__ internals>", line 6, in atleast_1d

File "D:\anaconda3\envs\pymarl\lib\site-packages\numpy\core\shape_base.py", line 65, in atleast_1d
ary = asanyarray(ary)

File "D:\anaconda3\envs\pymarl\lib\site-packages\torch\_tensor.py", line 678, in __array__
return self.numpy()

TypeError: can't convert cuda:0 device type tensor to numpy. Use Tensor.cpu() to copy the tensor to host memory first.

  4. I created an 'output_videos' folder in the main folder, then ran main.py and selected some rectangles, and I got an output video as a result.
  5. ★ So now, if I want action recognition for everyone, I need to draw a rectangle around each player one by one.

Failed to read video

Dear author, after downloading yolov3.cfg and the weights and putting them in the main folder, I ran into the following problem in 'main.py':


Reloaded modules: utils.checkpoints
4.6.0
Failed to read video
An exception has occurred, use %tb to see the full traceback.

SystemExit: 1

yolov3.cfg & yolov3.weight

Dear author, after downloading yolov3.cfg (https://github.com/qqwweee/keras-yolo3) and yolov3.weights (https://pjreddie.com/darknet/yolo/), where should I put them?

I tried the main folder, but it still doesn't work... I am a total beginner in this area, please help.

At first:
↓

File "D:\Basketball-Action-Recognition\main.py", line 137, in extractFrame
net = cv2.dnn.readNet(args.weights, args.config)

error: OpenCV(4.6.0) D:\a\opencv-python\opencv-python\opencv\modules\dnn\src\darknet\darknet_importer.cpp:210: error: (-212:Parsing error) Failed to open NetParameter file: yolov3.cfg in function 'cv::dnn::dnn4_v20220524::readNetFromDarknet'

After downloading yolov3.cfg and yolov3.weights and putting them in the main folder:
↓

File "D:\Basketball-Action-Recognition\main.py", line 137, in extractFrame
net = cv2.dnn.readNet(args.weights, args.config)

error: OpenCV(4.6.0) D:\a\opencv-python\opencv-python\opencv\modules\dnn\src\darknet\darknet_io.cpp:660: error: (-215:Assertion failed) separator_index < line.size() in function 'cv::dnn::darknet::ReadDarknetFromCfgStream'

The above was fixed
↓
but there is a new bug:

4.6.0
Failed to read video
An exception has occurred, use %tb to see the full traceback.
