kevinsbello / dagma

A Python 3 package for learning Bayesian Networks (DAGs) from data. Official implementation of the paper "DAGMA: Learning DAGs via M-matrices and a Log-Determinant Acyclicity Characterization"

Home Page: https://dagma.readthedocs.io/en/latest/

License: Apache License 2.0

structure-learning causal-discovery graphical-models neurips-2022 bayesian-networks dags

dagma's Introduction

DAGMA

The dagma library is a Python 3 package for learning DAGs (a.k.a. Bayesian networks) from data.

DAGMA works by optimizing a given score/loss function, where the structure that relates the variables is constrained to be a directed acyclic graph (DAG). Because the number of DAGs grows super-exponentially with the number of variables, the vanilla formulation results in a hard combinatorial optimization problem. DAGMA reformulates this problem by replacing the combinatorial constraint with a non-convex, differentiable function that exactly characterizes DAGs, thus making the optimization amenable to continuous methods such as gradient descent.

Citation

This is an implementation of the following paper:

[1] Bello K., Aragam B., Ravikumar P. (2022). DAGMA: Learning DAGs via M-matrices and a Log-Determinant Acyclicity Characterization. NeurIPS'22.

If you find this code useful, please consider citing:

BibTeX

@inproceedings{bello2022dagma,
    author = {Bello, Kevin and Aragam, Bryon and Ravikumar, Pradeep},
    booktitle = {Advances in Neural Information Processing Systems},
    title = {{DAGMA: Learning DAGs via M-matrices and a Log-Determinant Acyclicity Characterization}},
    year = {2022}
}

Features

  • Supports continuous data for linear (see dagma.linear) and nonlinear models (see dagma.nonlinear).
  • Supports binary (0/1) data for generalized linear models via dagma.linear.DagmaLinear with the logistic loss as the score (see the sketch after this list).
  • Runs faster than other continuous optimization methods for structure learning, e.g., NOTEARS and GOLEM.
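For instance, a minimal sketch of the binary-data case might look as follows; the loss_type value follows the package's documented options, while the data below is purely illustrative:

import numpy as np
from dagma.linear import DagmaLinear

# purely illustrative 0/1 data: 500 samples of 10 binary variables
rng = np.random.default_rng(0)
X_bin = rng.integers(0, 2, size=(500, 10)).astype(float)

model = DagmaLinear(loss_type='logistic')  # logistic score for binary data
W_est = model.fit(X_bin, lambda1=0.02)     # estimated weighted adjacency matrix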

Getting Started

Install the package

We recommend using a virtual environment via virtualenv or conda, and using pip to install the dagma package.

$ pip install dagma
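For example, one possible setup using Python's built-in venv module (the environment name is arbitrary):

$ python -m venv dagma-env
$ source dagma-env/bin/activate
$ pip install dagma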

Using dagma

See an example of how to use dagma in this Jupyter notebook.
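For a quick sense of the workflow, here is a minimal sketch following the documented linear example; the simulation helpers come from dagma.utils, and the hyperparameter values are illustrative:

from dagma import utils
from dagma.linear import DagmaLinear

utils.set_random_seed(1)

# simulate a 20-node ER DAG and 500 samples from a linear Gaussian SEM
n, d, s0 = 500, 20, 20
B_true = utils.simulate_dag(d, s0, 'ER')           # binary ground-truth DAG
W_true = utils.simulate_parameter(B_true)          # edge weights
X = utils.simulate_linear_sem(W_true, n, 'gauss')  # data matrix (n x d)

model = DagmaLinear(loss_type='l2')  # l2 score for continuous data
W_est = model.fit(X, lambda1=0.02)   # estimated weighted adjacency matrix
acc = utils.count_accuracy(B_true, W_est != 0)  # SHD, TPR, FDR, etc.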

An Overview of DAGMA

We propose a new acyclicity characterization of DAGs via a log-det function for learning DAGs from observational data. Similar to previously proposed acyclicity functions (e.g., NOTEARS), our characterization is exact and differentiable. However, compared to existing characterizations, our log-det function: (1) is better at detecting large cycles; (2) has better-behaved gradients; and (3) runs about an order of magnitude faster in practice. These advantages, together with a path-following scheme, lead to significant improvements in structure accuracy (e.g., SHD).

The log-det acyclicity characterization

Let $W \in \mathbb{R}^{d\times d}$ be the weighted adjacency matrix of a graph on $d$ nodes; the log-det function takes the following form:

$$h^{s}(W) = -\log \det (sI-W\circ W) + d \log s,$$

where $I$ is the identity matrix, $s$ is a given scalar (e.g., $s=1$), and $\circ$ denotes the element-wise Hadamard product. Of particular interest, $h^{s}(W) = 0$ if and only if $W$ represents a DAG; moreover, when the domain of $h^{s}$ is restricted to the set of M-matrices, $h^{s}$ is well-defined and non-negative. For more properties of $h^{s}(W)$ (e.g., invexity), its gradient $\nabla h^{s}(W)$, and its Hessian $\nabla^2 h^{s}(W)$, we refer the reader to [1].
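To make this concrete, here is a small numpy sketch of $h^{s}(W)$ (our own illustration, not the library's internal code); note that a DAG gives exactly zero, while any cycle gives a strictly positive value:

import numpy as np

def h_logdet(W, s=1.0):
    """Compute h^s(W) = -log det(sI - W∘W) + d log s."""
    d = W.shape[0]
    M = s * np.eye(d) - W * W            # sI - W∘W (Hadamard square)
    sign, logdet = np.linalg.slogdet(M)  # numerically stable log-determinant
    assert sign > 0, "W lies outside the M-matrix domain for this s"
    return -logdet + d * np.log(s)

print(h_logdet(np.array([[0., 1.2], [0., 0.]])))   # 0.0: W is a DAG
print(h_logdet(np.array([[0., 0.5], [0.5, 0.]])))  # ~0.065: W has a 2-cycle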

A path-following approach

Given the exact differentiable characterization of a DAG, we are interested in solving the following optimization problem:

$$\begin{array}{cl} \min_{W \in \mathbb{R}^{d \times d}} & Q(W;\mathbf{X}) \\ \text{subject to} & h^{s}(W) = 0, \end{array}$$

where $Q$ is a given score function (e.g., the square loss) that depends on $W$ and the dataset $\mathbf{X}$. To solve the above constrained problem, we propose a path-following approach where we solve a sequence of unconstrained problems of the form:

$$\hat{W}^{(t+1)} = \arg\min_{W}\; \mu^{(t)} Q(W;\mathbf{X}) + h^{s}(W),$$

where $\mu^{(t)} \to 0$ as $t$ increases. Leveraging the properties of $h$, we show that, at the limit, the solution is a DAG. The trick to make this work is to use the previous solution as a starting point when solving the current unconstrained problem, as usually done in interior-point algorithms. Finally, we use a simple accelerated gradient descent method to solve each unconstrained problem.
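A rough sketch of this scheme for the square loss, reusing the h_logdet setup above (the step size, schedule, and plain gradient step are illustrative and omit the safeguards of the actual solver, e.g., adaptive step sizes and keeping $sI - W \circ W$ an M-matrix):

import numpy as np

def grad_h(W, s=1.0):
    # gradient of h^s: 2 (sI - W∘W)^{-T} ∘ W (see [1])
    d = W.shape[0]
    return 2 * np.linalg.inv(s * np.eye(d) - W * W).T * W

def dagma_sketch(X, T=4, mu0=1.0, decay=0.1, lr=1e-3, inner=2000, s=1.0):
    n, d = X.shape
    W, mu = np.zeros((d, d)), mu0
    for _ in range(T):
        for _ in range(inner):  # inner problem, warm-started at the previous W
            grad_Q = -X.T @ (X - X @ W) / n  # gradient of (1/2n)||X - XW||_F^2
            W = W - lr * (mu * grad_Q + grad_h(W, s))
            np.fill_diagonal(W, 0.0)         # disallow self-loops
        mu *= decay  # mu -> 0: the limiting solution is a DAG
    return W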

Let us give an illustration of how DAGMA works in a two-node graph (see Figure 1 in [1] for more details). Here $w_1$ (the x-axis) represents the edge weight from node 1 to node 2; while $w_2$ (y-axis) represents the edge weight from node 2 to node 1. Moreover, in this example, the ground-truth DAG corresponds to $w_1 = 1.2$ and $w_2 = 0$.

Below are four plots, each illustrating the solution of an unconstrained problem for a different value of $\mu$. In the top-left plot, we have $\mu=1$: we solve the unconstrained problem starting at the empty graph (i.e., $w_1 = w_2 = 0$), which corresponds to the red point, and after running gradient descent we arrive at the cyan point ($w_1 = 1.06, w_2 = 0.24$). For the next unconstrained problem in the top-right plot, we have $\mu = 0.1$; we initialize gradient descent at the previous solution ($w_1 = 1.06, w_2 = 0.24$) and arrive at the cyan point $w_1 = 1.16, w_2 = 0.04$. Similarly, DAGMA solves for $\mu=0.01$ and $\mu=0.001$, and we can observe how the solution at the final iteration (bottom-right plot) is close to the ground-truth DAG $w_1 = 1.2, w_2 = 0$.

[Figure: dagma_4iters — four panels showing the gradient-descent solution path for $\mu = 1, 0.1, 0.01, 0.001$.]

Requirements

  • Python 3.7+
  • numpy
  • scipy
  • igraph
  • tqdm
  • torch: Only used for nonlinear models.

Contents

  • linear.py - implementation of DAGMA for linear models with l1 regularization (supports l2 and logistic losses).
  • nonlinear.py - implementation of DAGMA for nonlinear models using an MLP (see the sketch after this list).
  • locally_connected.py - special layer structure used by the MLP.
  • utils.py - graph simulation, data simulation, and accuracy evaluation.
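As a rough sketch of how the nonlinear pieces fit together (mirroring the documented nonlinear example; the layer sizes and hyperparameter values are illustrative, and X is an (n, d) data matrix as in the linear example above):

import torch
from dagma.nonlinear import DagmaMLP, DagmaNonlinear

d = 20  # number of variables
eq_model = DagmaMLP(dims=[d, 10, 1], bias=True, dtype=torch.double)  # one MLP per variable
model = DagmaNonlinear(eq_model, dtype=torch.double)
W_est = model.fit(X, lambda1=0.02, lambda2=0.005)  # estimated weighted adjacency matrix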

Acknowledgments

We thank the authors of the NOTEARS repo for making their code available. Part of our code is based on their implementation, especially the utils.py file and some code from their implementation of nonlinear models.


dagma's Issues

Selection of T

Hi there, thank you for this efficient algorithm!

I applied both the linear and nonlinear models to my real-world dataset, which has about 0.4M rows and 23 columns. I have the following questions about the selection of T for the linear and nonlinear models.

  1. For linear models, I tried setting T larger than the default and got an early stop. Can I consider this the final result? I found that the resulting matrix is slightly different from the one obtained with a smaller T.

  2. For the nonlinear model, I found that the values in the matrix are quite small for T=4 (on the order of 1e-5). When I increased T, they got smaller still, down to around 1e-19 at T=15. Does that mean the nonlinear model is not appropriate for this dataset, i.e., that the values tend to converge to 0? Note that I adapted the nonlinear algorithm to run on a GPU for shorter running times, so I slightly modified the code; this modification does not change the main steps of the algorithm.

Kind Regards,
Weikang

Implications of applying DAGMA on mixed data

Hi Kevin, thanks so much for providing the DAGMA algorithm; it runs fast and is more scalable than the other algorithms I have tried.

I understand that the current algorithm works on either continuous or binary data, but given that a large portion of real-world data is a mix of continuous, binary, and categorical variables, can I run the DAGMA linear algorithm on mixed data? I have tried it, and the results seem okay, but I do not know if this is theoretically sound. Another solution I have considered is to run DAGMA on the purely continuous data, identify the relationships, drop the non-essential continuous columns, and then run another algorithm that can handle mixed data but is not as scalable (such as DECI) to get the causal graph. I would like to hear your thoughts about this; thanks in advance!

Queries on pre-processing, variable types & overall performance

Hello! I've just had a look at the DAGMA paper, and it was really interesting how you recast the typical continuous optimization-based approach (from NOTEARS) into something that leverages the properties of DAGs!

However, I'm no expert in causal discovery (nor am I proficient in optimization methods, as I just barely managed to understand the high-level intuition of DAGMA), so I'm writing this post to clarify some doubts I have regarding the practical application/implementation of DAGMA:

(1) Similar to NOTEARS (per the paper "Unsuitability of NOTEARS for Causal Graph Discovery"), is DAGMA susceptible to rescaling of the data, and therefore not scale invariant, such that standardizing all variables to unit variance is a necessary preprocessing step?

(2) Which variable types (i.e., continuous, discrete, and categorical) can DAGMA take in at once? Or can DAGMA (at least for the linear models) only handle variables of a single type at a time?

(3) I understand that you have compared the performance and scalability of DAGMA to other continuous optimization-based approaches such as NOTEARS and GOLEM. If possible, where do you think it fits among approaches known to be more 'scalable' to larger numbers of nodes, such as NOBEARS and LEAST? I am curious how DAGMA's scalability compares to such algorithms; there is a nice summary of continuous optimization-based approaches in a recent paper from early last year (image attached below, from "D'ya Like DAGs? A Survey on Structure Learning and Causal Discovery").

[Image: summary table of continuous optimization-based approaches from "D'ya Like DAGs? A Survey on Structure Learning and Causal Discovery".]

I understand that answering all of these questions might be a lot to ask, but I am trying to clear up these doubts because I am really interested in whether DAGMA could be applied practically to my real-world dataset with ~300-700 variables (columns) and ~1M rows (my current implementation using DirectLiNGAM does not scale and is unfortunately constrained to continuous variables only).

Thanks so much in advance!

Hyperparameters for large-scale experiments

Hi there, would you please provide the hyperparameters for the large-scale experiments?
I tried using the hyperparameters from the small-scale experiments, but they did not work.

Best,

Zhen Zhang
