#
#  combinerds.py
#  ENRAGE
#

# psyco is an optional accelerator: use it when it is installed, and carry on
# without it otherwise.
try:
    import psyco
    psyco.full()
except Exception:
    # Best-effort only.  Catching Exception rather than using a bare except
    # lets KeyboardInterrupt and SystemExit propagate normally.
    pass

import sys
import ReadDataset

# Module-level banner: printed whenever this file is run or imported.
print "combinerds: version 1.2"


def main(argv=None):
    """Parse the command line and merge the input RDS files into the destination.

    argv follows the sys.argv layout (program name first) and defaults to
    sys.argv when empty or None.  argv[1] is the destination dataset; the
    arguments after it, up to the first one starting with '-', are the input
    datasets.

    Recognised options:
        --table name   import only this table (the historically documented
                       single-dash spelling -table is accepted as well);
                       without it every table reported by rds.getTables()
                       is imported
        --flag value   passed to the import as the flag restriction
        --init         create the destination dataset before importing
        --initrna      same as --init, with datasetType='RNA'
        --index        build the index after importing
        --cache [n]    open the destination with cache=True; n is applied
                       through setDBcache() only when it exceeds the
                       dataset's default cache size

    Exits with status 1 when no destination is given or when --table/--flag
    is not followed by a value.
    """
    if not argv:
        argv = sys.argv

    if len(argv) < 2:
        print('usage: python %s destinationRDS inputrds1 [inputrds2 ....] [--table table_name] [--flag flag] [--init] [--initrna] [--index] [--cache pages]' % argv[0])
        sys.exit(1)

    def optionValue(*names):
        """Return the argument following the first of names present in argv, or None."""
        for name in names:
            if name in argv:
                valueIndex = argv.index(name) + 1
                if valueIndex >= len(argv):
                    print("option %s requires a value" % name)
                    sys.exit(1)
                return argv[valueIndex]
        return None

    # Resolve the value-taking options before any dataset is touched, so a
    # malformed command line fails without creating or opening files.
    flagValue = optionValue('--flag')
    tableName = optionValue('--table', '-table')

    doCache = '--cache' in argv
    cachePages = -1
    if doCache:
        # The page count is optional: a missing or non-numeric value just
        # leaves the default cache size in effect.
        try:
            cachePages = int(argv[argv.index('--cache') + 1])
        except (IndexError, ValueError):
            pass

    datafile = argv[1]
    infileList = []
    for arg in argv[2:]:
        if arg.startswith('-'):
            break
        infileList.append(arg)

    print("destination RDS: %s" % datafile)

    # Creating the dataset is a side effect of constructing it with
    # initialize=True; it is reopened below for the actual import.
    if '--initrna' in argv:
        rds = ReadDataset.ReadDataset(datafile, initialize=True, datasetType='RNA')
    elif '--init' in argv:
        rds = ReadDataset.ReadDataset(datafile, initialize=True)

    withFlag = ''
    if flagValue is not None:
        withFlag = flagValue
        print("restrict to flag = %s" % withFlag)

    rds = ReadDataset.ReadDataset(datafile, verbose=True, cache=doCache)

    defaultCacheSize = rds.getDefaultCacheSize()
    if cachePages > defaultCacheSize:
        rds.setDBcache(cachePages)
        cacheVal = cachePages
    else:
        cacheVal = defaultCacheSize

    doIndex = '--index' in argv

    if tableName is not None:
        tableList = [tableName]
    else:
        tableList = rds.getTables()

    combinerds(datafile, rds, infileList, cacheVal, tableList, withFlag, doIndex, doCache)


def combinerds(datafile, rds, infileList, cacheVal, tableList=None, withFlag="", doIndex=False, doCache=False):
    """Import the requested tables from each input file into the open dataset rds.

    Each input file is attached under the alias "input<N>", where N continues
    from the "numberImports" metadata value of rds (0 when absent, in which
    case the entry is created).  For every table in tableList:

      * "uniqs", "multi" and "splices" are imported with the alias prepended
        to readID and NULL in the first column;
      * "metadata" is imported with " (import_<N>)" appended to each value,
        and without the flag restriction;
      * any other table is imported with all columns ("*").

    After each file an "import_<N>" metadata entry recording the file name
    and table list is added.  Finally "numberImports" is updated, the index
    is rebuilt when doIndex is set (passing cacheVal if it is positive), and
    the cached database is saved to datafile when doCache is set.

    tableList defaults to no tables; withFlag is handed to importFromDB for
    the non-metadata tables.
    """
    if tableList is None:
        tableList = []

    readColumns = "NULL, '%s' || readID, chrom, start, stop, sense, weight, flag, mismatch"
    spliceColumns = "NULL, '%s' || readID, chrom, startL, stopL, startR, stopR, sense, weight, flag, mismatch"

    metaDict = rds.getMetadata()
    if "numberImports" in metaDict:
        origIndex = int(metaDict["numberImports"])
    else:
        origIndex = 0
        rds.insertMetadata([("numberImports", str(0))])

    index = origIndex
    for inputfile in infileList:
        asname = "input" + str(index)
        rds.attachDB(inputfile, asname)
        for table in tableList:
            print("importing table %s from file %s" % (table, inputfile))
            if table == "metadata":
                # Metadata rows are tagged with their import number and are
                # never filtered by flag.
                rds.importFromDB(asname, table, "name, value || ' (import_%d)'" % index)
                continue

            if table in ("uniqs", "multi"):
                ascols = readColumns % asname
            elif table == "splices":
                ascols = spliceColumns % asname
            else:
                ascols = "*"

            rds.importFromDB(asname, table, ascols, withFlag)

        rds.detachDB(asname)
        rds.insertMetadata([("import_" + str(index), "%s %s" % (inputfile, str(tableList)))])
        index += 1

    rds.updateMetadata("numberImports", index, origIndex)
    if doIndex:
        print("building index....")
        if cacheVal > 0:
            rds.buildIndex(cacheVal)
        else:
            rds.buildIndex()

    if doCache:
        rds.saveCacheDB(datafile)


# Run the command-line entry point only when this file is executed as a script.
if __name__ == "__main__":
    main(sys.argv)