'''
Unit tests for erange.peakstoregion.

Created on Oct 4, 2010

@author: sau
'''
import unittest
import os
from erange import peakstoregion

# Scratch files shared by the tests below: setUp writes inFileName, the code
# under test writes outFileName, and tearDown deletes both. The paths are
# relative, so the files are created in the current working directory.
inFileName = "testPeaksToRegionInFile.txt"
outFileName = "testPeaksToRegionOutFile.txt"


class TestPeaksToRegion(unittest.TestCase):
    """Tests for ``peakstoregion.peakstoregion`` and ``peakstoregion.main``.

    ``setUp`` writes a three-line, tab-separated input file to ``inFileName``.
    Each test runs the code under test on it and compares the lines written
    to ``outFileName`` with the expected region lines. ``tearDown`` removes
    both files.
    """

    def setUp(self):
        """Write the shared input file used by every test."""
        # Each line has five tab-separated fields. The expected outputs below
        # drop the leading field, echo fields 1, 2 and 4, and replace field 3
        # with a start/end pair that brackets it.
        with open(inFileName, "w") as inFile:
            inFile.write("stuff\tpeak1\tchr1\t1000\t1.3\n")
            inFile.write("stuff\tpeak2\tchr1\t800\t9.7\n")
            inFile.write("stuff\tpeak3\tchr2\t1000\t3.0\n")

    def tearDown(self):
        """Delete the output and input files, ignoring removal errors."""
        # Best effort: e.g. a test that failed early may never have created
        # outFileName, and that must not turn into a tearDown error.
        for fileName in (outFileName, inFileName):
            try:
                os.remove(fileName)
            except OSError:
                pass

    def _readOutputLines(self):
        """Return the lines of outFileName; the file is closed even on error."""
        with open(outFileName) as output:
            return output.readlines()

    def _assertOutput(self, expected):
        """Assert that outFileName contains exactly the lines in *expected*."""
        # One list comparison checks the line count and every line at once,
        # and on failure unittest reports a diff of the two lists.
        self.assertEqual(expected, self._readOutputLines())

    def testPeaksToRegion(self):
        """peakstoregion() with default arguments writes one line per peak."""
        peakstoregion.peakstoregion(inFileName, outFileName)
        # With the defaults, each expected region spans 500 on either side of
        # the input value (1000 -> 500..1500, 800 -> 300..1300).
        self._assertOutput([
            "peak1\tchr1\t500\t1500\t1.3\n",
            "peak2\tchr1\t300\t1300\t9.7\n",
            "peak3\tchr2\t500\t1500\t3.0\n",
        ])

    def testMain(self):
        """main() gives the default output and accepts optional arguments."""
        argv = ["peakstoregion", inFileName, outFileName]
        peakstoregion.main(argv)
        self._assertOutput([
            "peak1\tchr1\t500\t1500\t1.3\n",
            "peak2\tchr1\t300\t1300\t9.7\n",
            "peak3\tchr2\t500\t1500\t3.0\n",
        ])

        # With 600 as the first optional argument, the expected regions widen
        # to 600 on either side of the input value.
        # NOTE(review): 2, 3, 1, -1 presumably select input columns; their
        # exact meaning is defined by peakstoregion.main -- confirm there.
        argv = ["peakstoregion", inFileName, outFileName, 600, 2, 3, 1, -1]
        peakstoregion.main(argv)
        self._assertOutput([
            "peak1\tchr1\t400\t1600\t1.3\n",
            "peak2\tchr1\t200\t1400\t9.7\n",
            "peak3\tchr2\t400\t1600\t3.0\n",
        ])


def suite():
    """Return a TestSuite containing every test method of TestPeaksToRegion."""
    # TestLoader.loadTestsFromTestCase is what unittest.makeSuite delegated
    # to. makeSuite was deprecated in Python 3.11 and removed in 3.13; the
    # loader method also exists on Python 2.
    testSuite = unittest.TestSuite()
    testSuite.addTest(
        unittest.TestLoader().loadTestsFromTestCase(TestPeaksToRegion))

    return testSuite


if __name__ == "__main__":
    # Run this module's tests with the standard unittest command-line runner.
    unittest.main()