import khmer
import screed
from screed.fasta import fasta_iter

import khmer_tst_utils as utils

def teardown():
    """Module-level test teardown: delegate to ``utils.cleanup()``.

    Presumably removes temporary files handed out by
    ``utils.get_temp_filename`` -- confirm in khmer_tst_utils.
    """
    utils.cleanup()

def load_fa_seq_names(filename):
    """Return the ``'name'`` field of every record in the FASTA file.

    Records are parsed with screed's ``fasta_iter`` and names are returned
    in the order the iterator yields them.

    :param filename: path to a FASTA file.
    :returns: list of record names.
    """
    # The original opened the file without ever closing it; the context
    # manager releases the handle once the list has been fully built.
    with open(filename) as fp:
        return [record['name'] for record in fasta_iter(fp)]

class Test_Filter(object):
    """Tests for khmer hashtable read filtering and readmask helpers.

    Each test reads a FASTA fixture via ``utils.get_test_data`` and, where
    output is produced, writes it to a path from ``utils.get_temp_filename``.
    """

    def test_abund(self):
        """output_fasta_kmer_pos_freq writes one line of all-'1' counts."""
        ht = khmer.new_hashtable(10, 4**10)

        filename = utils.get_test_data('test-abund-read.fa')
        outname = utils.get_temp_filename('test_abund.out')

        ht.consume_fasta(filename)
        ht.output_fasta_kmer_pos_freq(filename, outname)

        # Read inside a context manager so the handle is closed even when an
        # assertion below fails (previously it was only closed on success).
        with open(outname, "r") as fd:
            output = fd.readlines()

        assert len(output) == 1, output

        # A sequence of length L contains L - K + 1 k-mers. With K = 10 the
        # expected 105 positions correspond to a 114 bp read (assumed from
        # the fixture -- TODO confirm), each k-mer seen exactly once.
        counts = output[0].strip().split()
        assert ['1'] * (114 - 10 + 1) == counts, counts

    def test_filter_limit_n(self):
        """filter_fasta_file_limit_n keeps 1 read with last arg 7, 2 with 4."""
        ht = khmer.new_hashtable(4, 4**4)

        filename = utils.get_test_data('simple_3.fa')
        outname = utils.get_temp_filename('test_filter.out')

        total_reads, _ = ht.consume_fasta(filename)
        assert total_reads == 2, total_reads

        # NOTE(review): the positional arguments 2 and 7/4 are presumably a
        # count threshold and the "limit n" value -- confirm against the
        # khmer extension before naming them.
        (total_reads, n_seq_kept) = khmer.filter_fasta_file_limit_n(
            ht, filename, total_reads, outname, 2, 7)
        assert total_reads == 2, total_reads
        assert n_seq_kept == 1, n_seq_kept

        (total_reads, n_seq_kept) = khmer.filter_fasta_file_limit_n(
            ht, filename, total_reads, outname, 2, 4)
        assert total_reads == 2, total_reads
        assert n_seq_kept == 2, n_seq_kept

    def test_filter(self):
        """filter_fasta_file_any with threshold 2 keeps reads '1' and '2'."""
        ht = khmer.new_hashtable(10, 4**10)

        filename = utils.get_test_data('simple_1.fa')
        outname = utils.get_temp_filename('test_filter.out')

        (total_reads, n_consumed) = ht.consume_fasta(filename)
        assert total_reads == 3, total_reads
        assert n_consumed == 63, n_consumed

        (total_reads, n_seq_kept) = khmer.filter_fasta_file_any(
            ht, filename, total_reads, outname, 2)
        assert n_seq_kept == 2, n_seq_kept

        names = load_fa_seq_names(outname)
        assert names == ['1', '2'], names

    def test_filter_n(self):
        """filter_fasta_file_any with threshold 1 keeps reads '1'..'3'."""
        ht = khmer.new_hashtable(10, 4**10)

        filename = utils.get_test_data('simple_2.fa')
        outname = utils.get_temp_filename('test_filter.out')

        (total_reads, n_consumed) = ht.consume_fasta(filename)
        assert total_reads == 4, total_reads
        assert n_consumed == 63, n_consumed

        (total_reads, n_seq_kept) = khmer.filter_fasta_file_any(
            ht, filename, total_reads, outname, 1)
        assert n_seq_kept == 3, n_seq_kept

        names = load_fa_seq_names(outname)
        assert names == ['1', '2', '3'], names

    def test_consume_build_readmask(self):
        """consume_fasta_build_readmask clears the bit for the bad read."""
        ht = khmer.new_hashtable(10, 4**10)

        filename = utils.get_test_data('simple_2.fa')

        # sequence #4 (index 3) is bad; the new readmask should have that.
        (total_reads, n_consumed, readmask) = \
            ht.consume_fasta_build_readmask(filename)

        assert total_reads == 4, total_reads
        assert n_consumed == 63, n_consumed
        assert readmask.get(0)
        assert readmask.get(1)
        assert readmask.get(2)
        assert not readmask.get(3)

    def test_consume_update_readmask(self):
        """consume_fasta with update flag True clears the bad read's bit."""
        ht = khmer.new_hashtable(10, 4**10)

        filename = utils.get_test_data('simple_2.fa')

        readmask = khmer.new_readmask(4)

        # sequence #4 (index 3) is bad; the new readmask should have that.
        # NOTE(review): the two 0 arguments are positional and presumably
        # read-range bounds -- confirm. The final flag selects whether
        # readmask is updated (contrast test_consume_no_update_readmask).
        (total_reads, n_consumed) = ht.consume_fasta(filename, 0, 0,
                                                     readmask, True)
        assert total_reads == 4, total_reads
        assert n_consumed == 63, n_consumed
        assert readmask.get(0)
        assert readmask.get(1)
        assert readmask.get(2)
        assert not readmask.get(3)

    def test_consume_no_update_readmask(self):
        """consume_fasta with update flag False leaves the readmask intact."""
        ht = khmer.new_hashtable(10, 4**10)

        filename = utils.get_test_data('simple_2.fa')

        readmask = khmer.new_readmask(4)

        # sequence #4 (index 3) is bad; the new readmask should NOT have that.
        (total_reads, n_consumed) = ht.consume_fasta(filename, 0, 0,
                                                     readmask, False)
        assert total_reads == 4, total_reads
        assert n_consumed == 63, n_consumed
        assert readmask.get(0)
        assert readmask.get(1)
        assert readmask.get(2)
        assert readmask.get(3)          # NOT updated

    def test_readmask_1(self):
        """Masking out indices 1-3 leaves only read '1' in the output."""
        filename = utils.get_test_data('simple_2.fa')
        outname = utils.get_temp_filename('test_filter.out')

        readmask = khmer.new_readmask(4)
        readmask.set(1, False)
        readmask.set(2, False)
        readmask.set(3, False)

        readmask.filter_fasta_file(filename, outname)

        names = load_fa_seq_names(outname)
        assert names == ['1'], names

    def test_readmask_2(self):
        """Keeping only index 1 leaves only read '2' in the output."""
        filename = utils.get_test_data('simple_2.fa')
        outname = utils.get_temp_filename('test_filter.out')

        readmask = khmer.new_readmask(4)
        readmask.set(0, False)
        readmask.set(1, True)
        readmask.set(2, False)
        readmask.set(3, False)

        readmask.filter_fasta_file(filename, outname)

        names = load_fa_seq_names(outname)
        assert names == ['2'], names

def test_filter_sodd():
    """Check ht.trim_on_sodd against fixed expected trims.

    After high-sodd.fa is loaded into a hashbits table, the first two
    sequences below should be cut to fixed prefixes. The third should come
    back unchanged.
    """
    K = 32
    HASHTABLE_SIZE = int(8e7)
    N_HT = 4
    MAX_SODD = 3

    ht = khmer.new_hashbits(K, HASHTABLE_SIZE, N_HT)
    # NOTE(review): this path is relative to whatever directory
    # get_test_data joins it to, and climbs two levels up -- confirm it
    # still resolves when tests run from an installed package.
    filename = utils.get_test_data('../../data/high-sodd.fa')

    ht.consume_fasta(filename)

    untrimmed = "GCACGCAGATCGGAAGAGCGTCGTGTAGGGAAAGAGTGTAGATCTCGGTG"

    # (input sequence, expected trimmed sequence), checked in order.
    cases = [
        ("CGTTAGTTGCGGTGCCGACCGGCAAACTTGGTTTTGCCAAAAATTTTTACAGTTAGAAATTATTCACAAAGTTGCACCGGAATTCGGTTACAAACGTCATTCTAACTAAT",
         "CGTTAGTTGCGGTGCCGACCGGCAAACTTGGT"),
        ("ACAAAATTCCACATATAGTCATAATTGTGGGCAATTTTCGTCCCAAATTAGTTAGAATGACGTTTGTAACCGAATTCCGGTGCAACTTTGTGAATAATTTCTAACTGTAAAAAT",
         "ACAAAATTCCACATATAGTCATAATTGTGGGCAATT"),
        (untrimmed, untrimmed),
    ]

    for seq, expected in cases:
        # The trim position is returned too but is not checked here.
        trim_seq, _trim_at = ht.trim_on_sodd(seq, MAX_SODD)
        assert trim_seq == expected, (seq, trim_seq)
