@EternityZY I am the author of the C AVX2-vectorized code.
It is first necessary to understand that, in our work, causal networks are parametrized as stacked MLPs that compute a categorical distribution. This categorical distribution is either sampled from (`sample_mlp()`) or used to calculate log-probabilities and minimize cross-entropy (`logprob_mlp()`).
The problem size, and the corresponding MLP stack, are defined by the following hyperparameters:

- `M`: Number of causal variables.
- `N`: Array with the number of categories of each variable. For example, eight Boolean variables would be `[2,2,2,2,2,2,2,2]`.
- `Ns`: `sum(N)`, the sum of all variables' numbers of categories.
- `Nc`: `cumsum(N)`, the cumulative-sum array of the numbers of categories, starting with zero. For eight Boolean variables, this would be `[0,2,4,6,8,10,12,14]`. `Nc[-1] + N[-1] == Ns` by definition (illustrated in the sketch after this list).
- `Hgt`/`Hlr`/`H`: The "embedding size" of categorical variables' discrete values in the ground-truth (`Hgt`) and learned (`Hlr`) models.
The MLP stack is parametrized with the following arrays (a hypothetical initialization making the shapes concrete follows this list):

- `W0`: Stack of "embedding" matrices. Shape `(M, Ns, H)`.
- `B0`: Stack of biases for the hidden layer. Shape `(M, H)`.
- `W1`: Stack of weights for the output layer. Shape `(Ns, H)`, pre-stacked along axis 0, consistent with the slicing in the code below.
- `B1`: Stack of biases for the output layer. Shape `(Ns,)`.
- `config`: Adjacency matrix. Shape `(M, M)`.
- The leaky-ReLU has `alpha=0.1` and was chosen arbitrarily.
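For concreteness, a hypothetical initialization of these arrays (shapes taken from the list above; the sizes and initialization scheme are assumptions for illustration, not the repo's):

```python
import numpy as np

M, H = 8, 32                                        # Assumed sizes for illustration
rng = np.random.default_rng(0)
W0 = rng.standard_normal((M, Ns, H))                # Embedding matrices
B0 = np.zeros((M, H))                               # Hidden-layer biases
W1 = rng.standard_normal((Ns, H))                   # Output-layer weights, stacked along axis 0
B1 = np.zeros(Ns)                                   # Output-layer biases
config = np.tril(np.ones((M, M), dtype=int), k=-1)  # Lower-triangular adjacency matrix
```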
For each causal variable `i` from `0` to `M-1`, a single computation of the categorical distribution's parameters is roughly as follows:
```python
def calclogits(self, i, sample, config):
    # LAYER 0
    h = self.B0[i, :].copy()
    for j in parents(i, config):
        c = sample[j]  # Categorical value of parent j of i. Integer, in range [0 .. N[j]-1]
        # W0 is pre-stacked along axis 1 because not all variables have the same number of
        # categories, so we add Nc[j] to find the offset of variable j's value c's *embedding*.
        offj = c + self.Nc[j]
        h += self.W0[i, offj, :]  # Sum into h the embedding corresponding to parent j's value c
    # LEAKY RELU
    h = leaky_relu(h, alpha=0.1)
    # LAYER 1
    # W1 and B1 are pre-stacked along axis 0 because not all variables have the same number
    # of categories. Slice out of the stack the weights for variable i and use them.
    offi = self.Nc[i]
    B1i = self.B1[offi:offi+self.N[i]]      # Shape: (N[i],)
    W1i = self.W1[offi:offi+self.N[i], :]   # Shape: (N[i], H)
    outi = einsum('nh,h->n', W1i, h) + B1i  # Shape: (N[i],)
    return logsoftmax(outi)                 # Return shape: (N[i],)
```
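The pseudocode above uses a few helpers (`parents`, `leaky_relu`, `logsoftmax`) without defining them. Here are minimal NumPy sketches of what they could look like; the repo's actual implementations may differ, and the row-wise parent convention for `config` is an assumption:

```python
import numpy as np
from numpy import einsum, empty  # names used bare in the pseudocode above

def parents(i, config):
    # Assumed convention: row i of the adjacency matrix marks the parents of variable i.
    return np.flatnonzero(config[i])

def leaky_relu(x, alpha=0.1):
    return np.where(x >= 0.0, x, alpha * x)

def logsoftmax(x):
    x = x - x.max()                     # Shift for numerical stability
    return x - np.log(np.exp(x).sum())  # Normalized log-probabilities
```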
For sampling (`sample_mlp()`), with the normalized logits returned by `calclogits(i, sample, config)`, we sample a categorical value for variable `i`. We assume that `config` is lower-triangular and iterate `i` from `0` to `M-1`. If `config` is lower-triangular, then the variables are necessarily in topologically-sorted order, and iterating like this is causally ordered.
```python
for i in range(M):
    sample[i] = categorical(l=calclogits(i, sample, config))
return sample
```
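`categorical()` is not shown in the snippet; one way to implement it from log-probabilities is the Gumbel-max trick (a sketch under that assumption, not necessarily the repo's implementation):

```python
import numpy as np

def categorical(l, rng=np.random.default_rng()):
    # Gumbel-max trick: the argmax of log-probs plus i.i.d. Gumbel noise
    # is an exact sample from the corresponding categorical distribution.
    return int(np.argmax(l + rng.gumbel(size=l.shape)))
```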
For log-probability and backprop (`logprob_mlp()`), rather than sampling, we return one log-probability per variable, and neither need nor assume causal ordering:
```python
logprobs = empty(M)
for i in range(M):
    l = calclogits(i, sample, config)
    logprobs[i] = l[sample[i]]  # Because we did logsoftmax!
return logprobs
```
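To connect this to training: minimizing cross-entropy then amounts to minimizing the summed negative log-probabilities (a sketch of the assumed usage; in the PyTorch version this scalar would be the backprop target):

```python
nll = -logprobs.sum()  # Negative log-likelihood of the sample; the quantity to minimize
```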
The actual C code is a lot more complicated because it is vectorized and batched. I wrote it because it's extremely slow to run this in PyTorch.
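To give a flavor of what batching buys (a NumPy illustration, not the actual AVX2 C code), layer 0's embedding gather for variable `i` over a batch of `B` samples can be expressed as one fancy-indexing operation:

```python
import numpy as np

def layer0_batched(W0, B0, Nc, i, samples, parent_idx):
    # samples: (B, M) integer matrix of categorical values; parent_idx: indices of i's parents.
    offs = samples[:, parent_idx] + Nc[parent_idx]  # (B, P) offsets into axis 1 of W0
    return B0[i] + W0[i][offs].sum(axis=1)          # (B, H) hidden pre-activations
```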
Thank you very much for your reply. You explained every detail of the question very well. Thank you for your help!