Giter VIP home page Giter VIP logo

pseudoflow-parametric-cut's People

Contributors

dependabot[bot] avatar quic0 avatar robertoasin avatar

Stargazers

 avatar  avatar  avatar  avatar  avatar  avatar

Watchers

 avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar

pseudoflow-parametric-cut's Issues

Add igraph support

from ctypes import c_int, c_double, cast, byref, POINTER, cdll
from six.moves import xrange
import os

# Directory holding this module. The compiled solver, libhpf.so, is expected
# one level above it, and is loaded once at import time.
PATH = os.path.split(__file__)[0]
libhpf = cdll.LoadLibrary(os.path.join(PATH, os.pardir, "libhpf.so"))


def _c_arr(c_type, size, init):
    """Return a new ctypes array of ``size`` elements of ``c_type``.

    Elements are filled, in order, from the iterable ``init``. If ``init``
    is shorter than ``size``, ctypes zero-fills the remaining slots. If it
    is longer, ctypes raises.
    """
    ArrayType = c_type * size
    initial_values = tuple(init)
    return ArrayType(*initial_values)


def _get_arcmatrix(G, const_cap, mult_cap, source, sink):
    
    if 'networkx' in G.__module__:
        nNodes = G.number_of_nodes()
        nArcs = G.number_of_edges()
    
    elif 'igraph' in G.__module__:
        nNodes = G.vcount()
        nArcs = G.ecount()
    
    else:
        raise TypeError(
            "Graph should be networkx or igraph type. Please convert graph to one of those types."
        )

    nodeNames = []
    nodeDict = {}
    linearArcMatrix = []

    nodeDict[source] = 0
    nodeNames = [source]
    
    if 'networkx' in G.__module__:
        for node in G.nodes():
            if node not in {source, sink}:
                nodeDict[node] = len(nodeNames)
                nodeNames.append(node)
    
    elif 'igraph' in G.__module__:
        for node in G.vs:
            if node.index not in {source, sink}:
                nodeDict[node.index] = len(nodeNames)
                nodeNames.append(node.index)
    
    else:
        raise TypeError(
            "Graph should be networkx or igraph type. Please convert graph to one of those types."
        )

    nodeDict[sink] = len(nodeNames)
    nodeNames.append(sink)
    
    if 'networkx' in G.__module__:
        for fromNode, toNode, data in G.edges(data=True):
            linearArcMatrix += [
                nodeDict[fromNode],
                nodeDict[toNode],
                data[const_cap],
                data[mult_cap] if mult_cap else 0,
            ]
            
    elif 'igraph' in G.__module__:
        for e in G.es:
            linearArcMatrix += [
                nodeDict[e.source],
                nodeDict[e.target],
                e[const_cap],
                e[mult_cap] if mult_cap else 0,
            ]
    
    else:
        raise TypeError(
            "Graph should be networkx or igraph type. Please convert graph to one of those types."
        )

    return (nodeNames, nodeDict, map(lambda x: float(x), linearArcMatrix))


def _create_c_input(
    G, nodeDict, source, sink, arcMatrix, lambdaRange, roundNegativeCapacity
):
    """Package the solver inputs as ctypes values for ``hpf_solve``.

    The node and arc counts are read from ``G`` itself (networkx or igraph).
    The ctypes array for ``arcMatrix`` is sized at four entries per edge of
    ``G``, and the one for ``lambdaRange`` at two entries (presumably the
    lower and upper lambda bounds; confirm against libhpf).

    Returns:
        A dict of ctypes objects keyed by ``hpf_solve`` argument name.

    Raises:
        TypeError: if ``G`` is neither a networkx nor an igraph graph.
    """
    backend = G.__module__
    if 'networkx' in backend:
        node_count, arc_count = G.number_of_nodes(), G.number_of_edges()
    elif 'igraph' in backend:
        node_count, arc_count = G.vcount(), G.ecount()
    else:
        raise TypeError(
            "Graph should be networkx or igraph type. Please convert graph to one of those types."
        )

    # Entries are built in the same order as the solver's argument list.
    return {
        "numNodes": c_int(node_count),
        "numArcs": c_int(arc_count),
        "source": c_int(nodeDict[source]),
        "sink": c_int(nodeDict[sink]),
        "arcMatrix": _c_arr(c_double, arc_count * 4, arcMatrix),
        "lambdaRange": _c_arr(c_double, 2, lambdaRange),
        # The C side takes this boolean option as a 0/1 int.
        "roundNegativeCapacity": c_int(1 if roundNegativeCapacity else 0),
    }


def _create_c_output():
    """Allocate the ctypes output slots that ``hpf_solve`` fills in.

    - ``cuts`` and ``breakpoints`` start as NULL pointers for the solver to
      set.
    - ``stats`` (5 ints) and ``times`` (3 doubles) are zeroed fixed-size
      arrays.
    - ``numBreakpoints`` starts at 0.
    """
    return {
        "numBreakpoints": c_int(0),
        "cuts": POINTER(c_int)(),
        "breakpoints": POINTER(c_double)(),
        "stats": _c_arr(c_int, 5, [0] * 5),
        "times": _c_arr(c_double, 3, [0.0] * 3),
    }


def _solve(c_input, c_output):
    """Invoke libhpf's ``hpf_solve`` on prepared ctypes buffers.

    ``c_input`` is the dict produced by ``_create_c_input`` and ``c_output``
    the one produced by ``_create_c_output``. Output slots are passed by
    reference, or as arrays, so the C routine can write into them in place.
    The native return value is ignored.
    """
    native_solve = libhpf.hpf_solve
    # Declared positionally to mirror the C prototype, so ctypes can check
    # and convert each argument.
    native_solve.argtypes = [c_int] * 4 + [
        POINTER(c_double),           # arcMatrix
        c_double * 2,                # lambdaRange
        c_int,                       # roundNegativeCapacity
        POINTER(c_int),              # numBreakpoints (output)
        POINTER(POINTER(c_int)),     # cuts (output)
        POINTER(POINTER(c_double)),  # breakpoints (output)
        c_int * 5,                   # stats (output)
        c_double * 3,                # times (output)
    ]

    arguments = [c_input[key] for key in ("numNodes", "numArcs", "source", "sink")]
    arguments.append(cast(byref(c_input["arcMatrix"]), POINTER(c_double)))
    arguments.append(c_input["lambdaRange"])
    arguments.append(c_input["roundNegativeCapacity"])
    arguments.extend(byref(c_output[key]) for key in ("numBreakpoints", "cuts", "breakpoints"))
    arguments.append(c_output["stats"])
    arguments.append(c_output["times"])

    native_solve(*arguments)


def _cleanup(c_output):
    """Release the result buffers returned through ``c_output``'s pointers.

    Calls libhpf's ``libfree`` on ``breakpoints`` first and then on ``cuts``.
    The remaining entries are Python-owned ctypes objects and are left alone.
    """
    for key in ("breakpoints", "cuts"):
        libhpf.libfree(c_output[key])


def _check_multipliers_sink_adjacent_negative(G, sink, mult_cap):
    """Ensure every arc entering ``sink`` has a non-positive multiplier.

    Args:
        G: a networkx graph, or an igraph ``Graph`` (``sink`` is then a
            vertex index).
        sink: the sink node.
        mult_cap: name of the edge attribute holding the lambda multiplier.

    Raises:
        ValueError: if some arc ``(u, sink)`` has a positive multiplier.
        TypeError: if ``G`` is neither a networkx nor an igraph graph.
    """
    module = G.__module__
    if 'networkx' in module:
        incoming = ((u, G[u][sink][mult_cap]) for u in G.predecessors(sink))
    elif 'igraph' in module:
        # Arcs into the sink are those whose target is ``sink``. Read the
        # caller-supplied attribute name rather than a fixed one.
        incoming = ((e.source, e[mult_cap]) for e in G.es if e.target == sink)
    else:
        raise TypeError(
            "Graph should be networkx or igraph type. Please convert graph to one of those types."
        )

    for u, multiplier in incoming:
        if multiplier > 0:
            raise ValueError(
                "Sink adjacent arcs should have non-positive multipliers. Arc (%s, %s = sink) has a multiplier of %f. Please reverse graph."
                % (u, sink, multiplier)
            )


def _check_multipliers_source_adjacent_positive(G, source, mult_cap):
    """Ensure every arc leaving ``source`` has a non-negative multiplier.

    Args:
        G: a networkx graph, or an igraph ``Graph`` (``source`` is then a
            vertex index).
        source: the source node.
        mult_cap: name of the edge attribute holding the lambda multiplier.

    Raises:
        ValueError: if some arc ``(source, v)`` has a negative multiplier.
        TypeError: if ``G`` is neither a networkx nor an igraph graph.
    """
    module = G.__module__
    if 'networkx' in module:
        outgoing = ((v, G[source][v][mult_cap]) for v in G.successors(source))
    elif 'igraph' in module:
        # Source-adjacent arcs are those whose tail (``e.source``) is the
        # source. This matches the networkx ``successors`` check above.
        outgoing = ((e.target, e[mult_cap]) for e in G.es if e.source == source)
    else:
        raise TypeError(
            "Graph should be networkx or igraph type. Please convert graph to one of those types."
        )

    for v, multiplier in outgoing:
        if multiplier < 0:
            raise ValueError(
                "Source adjacent arcs should have non-negative multipliers. Arc (%s = source, %s) has a multiplier of %f. Please reverse graph."
                % (source, v, multiplier)
            )


def _read_output(c_output, nodeNames):
    """Convert the solver's raw ctypes results into Python objects.

    Args:
        c_output: the dict of ctypes buffers filled in by ``_solve``.
        nodeNames: node names ordered by solver index.

    Returns:
        ``(breakpoints, cuts, info)``:
        - ``breakpoints`` lists the first ``numBreakpoints`` values.
        - ``cuts`` maps each node name to one entry per breakpoint. The raw
          buffer is read as one row of ``len(nodeNames)`` entries per
          breakpoint.
        - ``info`` collects the solver counters and timings.
    """
    total = c_output["numBreakpoints"].value
    raw_breaks = c_output["breakpoints"]
    raw_cuts = c_output["cuts"]
    stride = len(nodeNames)

    breakpoints = [raw_breaks[k] for k in range(total)]
    cuts = dict(
        (name, [raw_cuts[row * stride + col] for row in range(total)])
        for col, name in enumerate(nodeNames)
    )

    stats = c_output["stats"]
    times = c_output["times"]
    # NOTE: the "intializationTime" key spelling is part of the returned API.
    info = {
        "numArcScans": stats[0],
        "numMergers": stats[1],
        "numPushes": stats[2],
        "numRelabels": stats[3],
        "numGap": stats[4],
        "readDataTime": times[0],
        "intializationTime": times[1],
        "solveTime": times[2],
    }

    return breakpoints, cuts, info


def hpf(
    G,
    source,
    sink,
    const_cap,
    mult_cap=None,
    lambdaRange=None,
    roundNegativeCapacity=False,
):
    """Solve a (parametric) minimum s-t cut problem on ``G`` with libhpf.

    Args:
        G: a networkx or igraph directed graph.
        source: the source node (a vertex index for igraph).
        sink: the sink node (a vertex index for igraph).
        const_cap: edge attribute holding the constant part of each capacity.
        mult_cap: edge attribute holding each capacity's lambda multiplier.
            When omitted, the problem is solved non-parametrically.
        lambdaRange: two-element sequence of lambda bounds. Required when
            ``mult_cap`` is given; ignored (replaced by ``[0.0, 0.0]``)
            otherwise.
        roundNegativeCapacity: passed to the solver as a 0/1 flag.

    Returns:
        ``(breakpoints, cuts, info)``:
        - ``breakpoints`` is the list of lambda breakpoints, or ``[None]``
          for a non-parametric solve.
        - ``cuts`` maps every node to its list of per-breakpoint cut values.
        - ``info`` is a dict of solver statistics and timings.

    Raises:
        TypeError: if ``G`` is not a networkx/igraph graph, or if
            ``mult_cap`` is given without ``lambdaRange``.
        ValueError: if a source- or sink-adjacent multiplier has the wrong
            sign.
    """
    parametric = bool(mult_cap)
    if parametric:
        if lambdaRange is None:
            # Fail early with a clear message. Otherwise the None would only
            # surface later as an opaque error when building the C array.
            raise TypeError("lambdaRange is required when mult_cap is given")
        _check_multipliers_sink_adjacent_negative(G, sink, mult_cap)
        _check_multipliers_source_adjacent_positive(G, source, mult_cap)
    else:
        lambdaRange = [0.0, 0.0]

    nodeNames, nodeDict, arcMatrix = _get_arcmatrix(
        G, const_cap, mult_cap, source, sink
    )

    c_input = _create_c_input(
        G, nodeDict, source, sink, arcMatrix, lambdaRange, roundNegativeCapacity
    )
    c_output = _create_c_output()

    _solve(c_input, c_output)

    try:
        breakpoints, cuts, info = _read_output(c_output, nodeNames)
    finally:
        # Free the buffers returned by the solver even if reading them fails,
        # so an exception here does not leak native memory.
        _cleanup(c_output)

    if not parametric:
        breakpoints = [None]

    return breakpoints, cuts, info

Simplify API

Separate parametric and non-parametric solver

Incorrect handling of parametric sink-adjacent arcs.

The following instance has a breakpoint at lambda = 1.5, but HPF does not find any breakpoint.

    # Reproduction graph for this issue (networkx; assumes
    # `from networkx import DiGraph`). Each arc stores the constant part of
    # its capacity under "const" and its lambda multiplier under "mult".
    # These are the attribute names one would pass to hpf() as const_cap
    # and mult_cap.
    G = DiGraph()
    G.add_nodes_from(['s', 't', 1, 2])
    G.add_edges_from([('s', 1), (1, 't'), (1, 2), (2, 't')])

    G['s'][1]["const"] = 2
    G[1][2]["const"] = 5
    G[1]['t']["const"] = 0
    G[2]['t']["const"] = float('inf')  # arc (2, 't') is uncapacitated

    # Only arcs ('s', 1) and (1, 't') depend on lambda. (1, 't') is
    # sink-adjacent with a negative multiplier, the case this issue is about.
    G['s'][1]["mult"] = 1
    G[1]['t']["mult"] = -1
    G[1][2]["mult"] = 0
    G[2]['t']["mult"] = 0

Recommend Projects

  • React photo React

    A declarative, efficient, and flexible JavaScript library for building user interfaces.

  • Vue.js photo Vue.js

    🖖 Vue.js is a progressive, incrementally-adoptable JavaScript framework for building UI on the web.

  • Typescript photo Typescript

    TypeScript is a superset of JavaScript that compiles to clean JavaScript output.

  • TensorFlow photo TensorFlow

    An Open Source Machine Learning Framework for Everyone

  • Django photo Django

    The Web framework for perfectionists with deadlines.

  • D3 photo D3

    Bring data to life with SVG, Canvas and HTML. 📊📈🎉

Recommend Topics

  • javascript

    JavaScript (JS) is a lightweight interpreted programming language with first-class functions.

  • web

    Something interesting about the web. A new door to the world.

  • server

    A server is a program made to process requests and deliver data to clients.

  • Machine learning

    Machine learning is a way of modeling and interpreting data that allows a piece of software to respond intelligently.

  • Game

    Something interesting about games, to make everyone happy.

Recommend Org

  • Facebook photo Facebook

    We are working to build community through open source technology. NB: members must have two-factor auth.

  • Microsoft photo Microsoft

    Open source projects and samples from Microsoft.

  • Google photo Google

    Google โค๏ธ Open Source for everyone.

  • D3 photo D3

    Data-Driven Documents codes.