
Comments (2)

obilaniu commented on July 23, 2024

@EternityZY I am the author of the C AVX2-vectorized code.

It is first necessary to understand that in our work, causal networks are parametrized as stacked MLPs that compute a categorical distribution. This categorical distribution is either sampled from (sample_mlp()) or used to compute log-probabilities and minimize cross-entropy (logprob_mlp()).

The problem size, and corresponding MLP stack, is defined by the following hyperparameters:

  • M: Number of causal variables.
  • N: Array with number of categories in each variable. For example, eight Boolean variables would be [2,2,2,2,2,2,2,2].
  • Ns: sum(N), the sum of all variables' number of categories.
  • Nc: cumsum(N), the cumulative-sum array of the numbers of categories, starting with zero. For eight Boolean variables, Nc would be [0,2,4,6,8,10,12,14]. Nc[-1] + N[-1] == Ns by definition.
  • Hgt/Hlr/H: The "embedding size" of categorical variables' discrete values (Hgt in the ground-truth model, Hlr in the learned model; H when either is meant).
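Concretely, Ns and Nc follow from N in a couple of lines. A minimal NumPy sketch (variable names follow the list above):

```python
import numpy as np

# Eight Boolean variables: each has 2 categories.
N  = np.array([2, 2, 2, 2, 2, 2, 2, 2])
M  = len(N)                      # Number of causal variables: 8
Ns = int(N.sum())                # Total number of categories across all variables
# Cumulative sum starting at zero: the offset of each variable's
# first category inside any Ns-sized stacked axis.
Nc = np.concatenate(([0], np.cumsum(N)[:-1]))
print(Ns)   # 16
print(Nc)   # [ 0  2  4  6  8 10 12 14]
assert Nc[-1] + N[-1] == Ns
```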

The MLP stack is parameterized with the following arrays:

  • W0: Stack of "embedding" matrices. Shape (M,Ns,H).
  • B0: Stack of biases for hidden layer. Shape (M,H).
  • W1: Stack of weights for output layer, pre-stacked along axis 0. Shape (Ns,H).
  • B1: Stack of biases for output layer. Shape (Ns,).
  • config: Adjacency matrix. Shape (M,M).
  • The leaky-ReLU has alpha=0.1 and was chosen arbitrarily.
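As a sketch, a toy parameter stack with these shapes could be allocated like this (NumPy; the random initialization and the H=32 value are illustrative assumptions, not the actual initialization):

```python
import numpy as np

rng  = np.random.default_rng(0)
M, H = 8, 32                     # 8 variables, hypothetical embedding size 32
N    = np.full(M, 2)             # All Boolean
Ns   = int(N.sum())

W0 = rng.standard_normal((M, Ns, H))   # Embedding matrices, stacked along axis 1
B0 = rng.standard_normal((M, H))       # Hidden-layer biases
W1 = rng.standard_normal((Ns, H))      # Output weights, stacked along axis 0
B1 = rng.standard_normal(Ns)           # Output biases, stacked along axis 0
config = np.tril(np.ones((M, M), dtype=bool), k=-1)  # Lower-triangular adjacency
```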

For each causal variable i from 0 to M-1, a single computation of the categorical distribution's parameters is roughly as follows:

def calclogits(i, sample, config):
    # LAYER 0
    h = self.B0[i, :]
    for j in parents(i, config):
        c = sample[j] # Categorical value of parent j of i. Integer, in range [0 .. N[j]-1]
        # W0 is pre-stacked along axis 1 because not all variables have the same number of categories.
        # So need to add Nc[j] in order to find the offset of variable j's categorical value c's *embedding* 
        offj = c + self.Nc[j]
        h += self.W0[i, offj, :]   # Sum into h the embedding corresponding to parent j's value c
    
    # LEAKY RELU
    h = leaky_relu(h, alpha=0.1)
    
    # LAYER 1
    # W1 and B1 are pre-stacked along axis 0 because not all variables have the same number of categories.
    # Slice out of the stack the weights for variable i and use them.
    offi = self.Nc[i]
    B1i  = self.B1[offi:offi+self.N[i]]    # Shape: (N[i],)
    W1i  = self.W1[offi:offi+self.N[i], :] # Shape: (N[i], H)
    outi = einsum('nh,h->n', W1i, h) + B1i # Shape: (N[i],)
    return logsoftmax(outi)  # Return shape: (N[i],)

For sampling (sample_mlp()), with the normalized logits returned by calclogits(i, sample, config), we sample a categorical value for variable i. We assume that config is lower-triangular and iterate i from 0 to M-1. If config is lower-triangular, then necessarily all variables are in topologically-sorted order and iterating like this is causally-ordered.

for i in range(M):
    sample[i] = categorical(l=calclogits(i, sample, config))
return sample
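The categorical() step is not spelled out above. One standard way to draw an exact sample from normalized logits is the Gumbel-max trick (an illustrative choice, not necessarily what the C code does):

```python
import numpy as np

def categorical(l, rng):
    # Gumbel-max trick: argmax of logits plus i.i.d. Gumbel noise
    # is an exact sample from the categorical distribution softmax(l).
    g = -np.log(-np.log(rng.uniform(size=l.shape)))
    return int(np.argmax(l + g))

rng   = np.random.default_rng(0)
logp  = np.log([0.1, 0.2, 0.7])          # Already-normalized log-probabilities
draws = [categorical(logp, rng) for _ in range(10000)]
freq  = np.bincount(draws, minlength=3) / len(draws)
# freq should be close to [0.1, 0.2, 0.7]
```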

For log-probability and backprop (logprob_mlp()), rather than sampling, we return one log-probability per variable, and neither need nor assume causal ordering:

logprobs = empty(M)
for i in range(M):
    l = calclogits(i, sample, config)
    logprobs[i] = l[sample[i]] # Because we did logsoftmax!
return logprobs
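Putting the pieces together, here is a self-contained NumPy transcription of the pseudocode above. It is a sketch, not the vectorized C implementation, and it assumes that row i of config marks variable i's parents:

```python
import numpy as np

def leaky_relu(x, alpha=0.1):
    return np.where(x > 0, x, alpha * x)

def logsoftmax(x):
    x = x - x.max()                              # for numerical stability
    return x - np.log(np.exp(x).sum())

class MLPStack:
    """Plain-NumPy transcription of the pseudocode above (a sketch)."""
    def __init__(self, N, H, rng):
        self.N  = np.asarray(N)
        self.M  = len(self.N)
        self.Ns = int(self.N.sum())
        self.Nc = np.concatenate(([0], np.cumsum(self.N)[:-1]))
        self.W0 = 0.1 * rng.standard_normal((self.M, self.Ns, H))
        self.B0 = 0.1 * rng.standard_normal((self.M, H))
        self.W1 = 0.1 * rng.standard_normal((self.Ns, H))
        self.B1 = 0.1 * rng.standard_normal(self.Ns)

    def calclogits(self, i, sample, config):
        h = self.B0[i].copy()
        for j in np.flatnonzero(config[i]):      # assumed: row i marks i's parents
            h += self.W0[i, sample[j] + self.Nc[j]]
        h = leaky_relu(h)
        o = self.Nc[i]
        out = self.W1[o:o + self.N[i]] @ h + self.B1[o:o + self.N[i]]
        return logsoftmax(out)

    def sample_mlp(self, config, rng):
        sample = np.zeros(self.M, dtype=int)
        for i in range(self.M):                  # lower-triangular config => causal order
            p = np.exp(self.calclogits(i, sample, config))
            sample[i] = int(rng.choice(self.N[i], p=p / p.sum()))
        return sample

    def logprob_mlp(self, sample, config):
        return np.array([self.calclogits(i, sample, config)[sample[i]]
                         for i in range(self.M)])

rng    = np.random.default_rng(0)
mlp    = MLPStack(N=[2, 3, 2], H=8, rng=rng)
config = np.tril(np.ones((3, 3), dtype=bool), k=-1)
s  = mlp.sample_mlp(config, rng)
lp = mlp.logprob_mlp(s, config)
```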

The actual C code is a lot more complicated because it is vectorized and batched. I wrote it because it's extremely slow to run this in PyTorch.

from causal_learning_unknown_interventions.

EternityZY commented on July 23, 2024

Thank you very much for your reply. You explained every detail of the question very well. Thank you for your help!

