A PyTorch implementation of Expert Concept-Based Learning (ECBL), as described in the preprint by Konstantinov and Utkin (2024).
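Assume that not only a single target is available for each input, but several concepts $y_0, \dots, y_{m-1}$, where concept $y_i$ takes one of $K_i$ values.
Let a rule be a logical expression over the concept values, for example an implication between them, that encodes expert knowledge the predictions must respect.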
The key idea is to build a neural network layer that estimates probability distributions of the concepts such that the given rules are satisfied.
We estimate the marginal probability of each concept from the joint probability distribution of all the concepts, under the constraint that the given rules are satisfied.
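Written out, the constrained marginal might look as follows. The notation ($x$ for the input, $K_i$ for the cardinality of concept $y_i$, $\mathcal{A}$ for the set of admissible joint states) is assumed here for illustration and may differ from the paper's. The joint distribution is supported only on states that satisfy every rule, and each concept's marginal sums the joint probabilities of those states:

$$
\mathcal{A} = \Big\{\, y \in \prod_{i=0}^{m-1} \{0, \dots, K_i - 1\} \;:\; \text{every rule holds for } y \,\Big\},
\qquad
\sum_{y \in \mathcal{A}} P(y \mid x) = 1,
$$

$$
P(y_i = k \mid x) = \sum_{y \in \mathcal{A} \,:\, y_i = k} P(y \mid x).
$$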
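Let's consider a concept vector $y = (y_0, y_1, y_2, y_3)$, where $y_0$ and $y_1$ take 2 values and $y_2$ and $y_3$ take 3 values, with the single rule $(y_0 = 1) \Rightarrow (y_2 = 1) \land (y_1 = 1)$.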
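For a small example like this one, the admissible states can be enumerated by brute force. The sketch below is not part of the `ecbl` API: the helper `admissible_states` is hypothetical, and it simply substitutes every joint state into the SymPy rules and keeps the states for which all rules evaluate to true.

```python
from itertools import product

from sympy import Eq, Implies, Symbol

# Concept cardinalities and rule from the example in this README
concepts = {'y_0': 2, 'y_1': 2, 'y_2': 3, 'y_3': 3}
rules = [
    # (y_0 == 1) => ((y_2 == 1) & (y_1 == 1))
    Implies(Eq(Symbol('y_0'), 1), Eq(Symbol('y_2'), 1) & Eq(Symbol('y_1'), 1)),
]


def admissible_states(concepts, rules):
    """Hypothetical helper: return all joint states that satisfy every rule.

    The number of joint states is the product of the cardinalities,
    so this brute-force loop is only meant for small examples.
    """
    symbols = [Symbol(name) for name in concepts]
    states = []
    for values in product(*(range(card) for card in concepts.values())):
        assignment = dict(zip(symbols, values))
        # After substitution each Eq compares two integers, so every rule
        # collapses to SymPy's true or false
        if all(bool(rule.subs(assignment)) for rule in rules):
            states.append(values)
    return states


states = admissible_states(concepts, rules)
print(len(states))  # 21
```

Of the 2 × 2 × 3 × 3 = 36 joint states, the rule removes the 15 states with y_0 = 1 that do not also have y_1 = 1 and y_2 = 1, which leaves 21 admissible states.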
With `ecbl`, the corresponding layer can be implemented as follows:
import torch
# we use SymPy to define rules
from sympy import Symbol, Eq, Implies, Equivalent
from ecbl import (
    ConceptsHeadWrapper,  # basic optimizations
    AdmissibleStatesHead,
    ConstraintsHead,
)
# specify concept cardinalities (number of values):
concepts = {'y_0': 2, 'y_1': 2, 'y_2': 3, 'y_3': 3}
# define a rule set
rules = [
    # (y_0 == 1) => ((y_2 == 1) & (y_1 == 1))
    Implies(
        Eq(Symbol('y_0'), 1),
        Eq(Symbol('y_2'), 1) & Eq(Symbol('y_1'), 1)
    ),
]
# n_in_features, nn_encoder and X are placeholders to be defined by the user
concepts_head = ConceptsHeadWrapper(
    in_features=n_in_features,  # neural network embedding size
    concepts=concepts,  # concept cardinalities
    rules=rules,  # rules
    head_cls=AdmissibleStatesHead,  # core concept-layer class
)
model = torch.nn.Sequential(
    nn_encoder,  # some neural network encoder that maps X to an embedding
    concepts_head,
)
preds = model(X)
# preds['<concept name>'] is a probability distribution over the concept's values;
# dimension 1 has the size of the concept's cardinality
assert preds['y_0'].shape[1] == 2
assert preds['y_3'].shape[1] == 3
# the encoder can be trained by backpropagating through `concepts_head`
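To make the mechanism concrete, here is a minimal plain-PyTorch sketch of a head like the one the name `AdmissibleStatesHead` suggests. It is an assumption about how such a layer could work, not the package's code: the class `AdmissibleStatesSketch` and its `is_admissible` predicate are hypothetical, the rule is written as a Python function instead of SymPy, and the admissible states are enumerated by brute force, which only scales to a handful of concepts.

```python
from itertools import product

import torch
import torch.nn as nn
import torch.nn.functional as F


class AdmissibleStatesSketch(nn.Module):
    """Hypothetical illustration, not the ecbl implementation.

    Predicts a distribution over admissible joint concept states only,
    then sums it into one marginal distribution per concept.
    """

    def __init__(self, in_features, concepts, is_admissible):
        super().__init__()
        self.names = list(concepts)
        cards = list(concepts.values())
        # Brute-force enumeration of the joint states that satisfy the rules
        states = torch.tensor([
            s for s in product(*(range(c) for c in cards))
            if is_admissible(dict(zip(self.names, s)))
        ])  # shape: (n_admissible, n_concepts)
        # One logit per admissible state; rule-violating states get no probability mass
        self.logits = nn.Linear(in_features, len(states))
        # For each concept, a (n_admissible, cardinality) one-hot map from state to value
        for i, card in enumerate(cards):
            self.register_buffer(f"onehot_{i}", F.one_hot(states[:, i], card).float())

    def forward(self, z):
        p_states = torch.softmax(self.logits(z), dim=-1)  # (batch, n_admissible)
        # Marginal of concept i: state probabilities summed by the value of y_i
        return {
            name: p_states @ getattr(self, f"onehot_{i}")
            for i, name in enumerate(self.names)
        }


# Same rule as above: (y_0 == 1) => ((y_2 == 1) & (y_1 == 1))
def rule(s):
    return s["y_0"] != 1 or (s["y_1"] == 1 and s["y_2"] == 1)


torch.manual_seed(0)
concepts = {"y_0": 2, "y_1": 2, "y_2": 3, "y_3": 3}
head = AdmissibleStatesSketch(in_features=16, concepts=concepts, is_admissible=rule)
preds = head(torch.randn(8, 16))
print({name: tuple(p.shape) for name, p in preds.items()})
# {'y_0': (8, 2), 'y_1': (8, 2), 'y_2': (8, 3), 'y_3': (8, 3)}
```

Because rule-violating states never receive a logit, the predicted marginals respect the rule by construction: for every input, P(y_0 = 1) cannot exceed P(y_1 = 1) or P(y_2 = 1).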
More detailed examples can be found in the notebooks.
The package is under development and can be installed from this Git repository:
pip install git+https://github.com/andruekonst/ecbl.git
Or clone the repo and install in development mode:
git clone https://github.com/andruekonst/ecbl.git
cd ecbl
pip install -e .
All the methods are presented in the following preprint:
@article{konstantinov2024incorporating,
title={Incorporating Expert Rules into Neural Networks in the Framework of Concept-Based Learning},
author={Konstantinov, Andrei V and Utkin, Lev V},
journal={arXiv preprint arXiv:2402.14726},
year={2024}
}