import os
import subprocess
import sys


def _shortenNodeName(name):
    """Return the second '_'-separated component of name, or name unchanged if it has none."""
    parts = name.split("_")
    if len(parts) > 1:
        return parts[1]
    return name


def getEdges(nodeList, shorten=False):
    """Build an undirected, weighted adjacency map from tab-separated edge lines.

    Each line is expected to be ``node1<TAB>node2<TAB>count``. Lines that do
    not split into exactly three fields, or whose count is not an integer
    (e.g. a header row), are skipped.

    :param nodeList: iterable of text lines (e.g. an open file).
    :param shorten: if true, replace each node name by its second
        ``_``-separated component when it has one ("a_b_c" -> "b").
    :returns: dict mapping each node to a list of ``(neighbour, count)``
        tuples. Every edge is recorded under both endpoints; an identical
        ``(neighbour, count)`` pair is stored only once per node.
    """
    edgeDict = {}

    for nodeEntry in nodeList:
        fields = nodeEntry.strip().split("\t")
        if len(fields) != 3:
            continue
        node1, node2, countText = fields
        try:
            count = int(countText)
        except ValueError:
            # Non-numeric weight (header/comment line): skip it like any
            # other malformed line instead of aborting the whole parse.
            continue

        if shorten:
            node1 = _shortenNodeName(node1)
            node2 = _shortenNodeName(node2)

        # Record the edge in both directions (node1 first, as before).
        for source, target in ((node1, node2), (node2, node1)):
            neighbours = edgeDict.setdefault(source, [])
            detail = (target, count)
            if detail not in neighbours:
                neighbours.append(detail)

    return edgeDict


def getEdgesFromFile(inFileName, shorten=False):
    """Read an edge-list file and return its adjacency map via getEdges.

    :param inFileName: path to a tab-separated ``node1 node2 count`` file.
    :param shorten: forwarded to getEdges.
    :returns: the dict produced by getEdges.
    """
    # The with-block closes the file even if parsing raises.
    with open(inFileName) as infile:
        return getEdges(infile, shorten)


def getOutputLine(currentNode, node, nodeCount):
    """Return one tab-indented, newline-terminated DOT edge statement.

    The edge is red, not used for ranking (constraint=false), and labelled
    with nodeCount. When nodeCount exceeds 2 the pen width is set to it too.
    """
    penWidth = ""
    if nodeCount > 2:
        penWidth = " penwidth=%d," % nodeCount
    template = '\t"%s" -- "%s" [ label = "%d",%s color="red", constraint=false] ; \n'
    return template % (currentNode, node, nodeCount, penWidth)


# Command line: <edge file> <output prefix> [-shorten]
# "-shorten" may appear anywhere; every other argument is positional.
positionalArgs = [arg for arg in sys.argv[1:] if arg != "-shorten"]
if len(positionalArgs) < 2:
    sys.exit("usage: %s <edge file> <output prefix> [-shorten]" % sys.argv[0])
infilename = positionalArgs[0]
outprefix = positionalArgs[1]

shorten = "-shorten" in sys.argv

edgeDict = getEdgesFromFile(infilename, shorten)

# Module-level state shared with visitNodes and the tree-building loop.
nodeList = list(edgeDict.keys())
seenNodeDict = {}      # nodes already visited by any traversal
seenEdgeDict = {}      # str(sorted([a, b])) keys of edges already emitted
currentNodeList = []   # nodes of the component currently being built
currentEdgeList = []   # DOT edge lines of the current component
treeList = []          # root node of each component, in output order
localCount = []        # one-element list: summed edge weight of the current component

outstat = open("%s.stats" % outprefix, "w")
outstat.write("#gID\tnodes\tedges\tweight\n")

def visitNodes(currentNode):
    """Depth-first walk of the component containing currentNode.

    Appends to the module-level accumulators:
      - currentNodeList: newly reached nodes;
      - currentEdgeList: one DOT line per not-yet-emitted edge;
      - localCount[0]: the summed edge weights.
    It also marks nodes in seenNodeDict and edges in seenEdgeDict.

    This is implemented with an explicit stack of neighbour iterators rather
    than recursion. The visit order is exactly that of the recursive version,
    but long paths can no longer hit the interpreter recursion limit. The old
    code silently swallowed that error, which truncated components.
    """
    if currentNode in seenNodeDict:
        return

    seenNodeDict[currentNode] = []
    stack = [(currentNode, iter(edgeDict[currentNode]))]
    while stack:
        parent, neighbours = stack[-1]
        for (node, nodeCount) in neighbours:
            # An undirected edge is identified by its sorted endpoint pair.
            edgeKey = str(sorted([node, parent]))
            if edgeKey in seenEdgeDict:
                continue

            if node not in currentNodeList:
                currentNodeList.append(node)
            currentEdgeList.append(getOutputLine(parent, node, nodeCount))
            seenEdgeDict[edgeKey] = 0
            localCount[0] += nodeCount

            if node not in seenNodeDict:
                # Descend into the new node; the parent's iterator keeps its
                # position, so we resume after this neighbour when we return.
                seenNodeDict[node] = []
                stack.append((node, iter(edgeDict[node])))
                break
        else:
            # All neighbours of parent handled: backtrack.
            stack.pop()

print("getting trees")
for node in nodeList:
    if node in seenNodeDict:
        continue

    # Each unseen node roots a new connected component.
    currentNodeList = [node]
    currentEdgeList = []
    localCount = [0]
    treeList.append(node)
    visitNodes(node)
    currentNodeList.sort()

    with open("%s.%s.gv" % (outprefix, node), "w") as outfile:
        # Quote the graph ID like the node IDs: an unquoted "g<node>" is
        # invalid DOT when the node name contains e.g. '.' or '-'.
        outfile.write('graph "g%s" {\n' % node)
        # Invisible-ordering chain through all nodes in sorted order.
        outfile.write('subgraph G0 {\n\t"%s" ' % currentNodeList[0])
        for anode in currentNodeList[1:]:
            outfile.write('-- "%s" ' % anode)
        outfile.write(" [ weight = 100 ] ;\n\tordering = out ;\n}\n")
        outfile.writelines(currentEdgeList)
        outfile.write("}\n")

    outstat.write("%s\t%d\t%d\t%d\n" % (node, len(currentNodeList), len(currentEdgeList), localCount[0]))

print("generating pngs")
for node in treeList:
    gvName = "%s.%s.gv" % (outprefix, node)
    pngName = "%s.%s.png" % (outprefix, node)
    # Pass arguments as a list (no shell): node names come from the input
    # file and may contain spaces or shell metacharacters.
    try:
        with open(pngName, "wb") as pngFile:
            returnCode = subprocess.call(["dot", "-Tpng", gvName], stdout=pngFile)
    except OSError as err:
        sys.stderr.write("could not run dot: %s\n" % err)
        break
    if returnCode != 0:
        sys.stderr.write("dot failed for %s (exit status %d)\n" % (gvName, returnCode))

outstat.close()