#!/usr/bin/env perl

# count_fasta_residues.pl -- Erich Schwarz <emsch@caltech.edu>, 2/25/2012. 
# Purpose: get basic statistics from FASTA files, both scaffolds and true contigs; counts residues directly.

use strict;
use warnings;
use Getopt::Long;
use Statistics::Descriptive;
use File::Basename;
use List::MoreUtils qw(uniq);

# Command-line options.
my @infiles      = ();    # Input FASTA files; '-' is read as STDIN by the main loop.
my $basename;             # If set, report only the basename of each input file.
my $extra;                # If set, also report mean, median and standard deviation.
my $help;                 # If set, print the usage message and stop.

# Per-file working data; the main loop resets all of these for each input file.
my $nt_count     = 0;     # Residues counted so far in the scaffold being read.
my $N_count      = 0;     # Total N/n residues seen in the file.
my $scaf_name    = q{};   # Name of the scaffold being read.
my @scaf_sizes   = ();    # Sizes of all scaffolds in the file.
my @contig_sizes = ();    # Sizes of all contigs (non-N stretches) in the file.
my %scaf_seqs    = ();    # Scaffold name => its concatenated sequence.

# GetOptions returns false when the command line could not be parsed (e.g. an
# unknown option); in that case show the usage message instead of running with
# a partly-understood command line.
my $options_ok = GetOptions ( 'infiles=s{,}' => \@infiles,
                              'basename'     => \$basename,
                              'extra'        => \$extra,
                              'help'         => \$help,   );

if ( $help or (! $options_ok) or (! @infiles) ) { 
    die "Format: count_fasta_residues.pl\n",
        "    --infiles|-i  <input stream/files>\n",
        "    --basename|-b [only give input file basenames]\n",
        "    --extra|-e    [give mean, median, SDs of contigs/scaffolds]\n",
        "    --help|-h     [print this message]\n",
        ;
}

# Main loop: for each input file (or '-' for STDIN), parse the FASTA records,
# tally scaffold and contig sizes, and print a summary report to STDOUT.
# A scaffold is one FASTA record; contigs are the stretches of non-N residues
# obtained by splitting each scaffold's sequence at runs of N.
foreach my $infile (@infiles) { 
    # Reset the per-file accumulators so that nothing leaks between files.
    $nt_count     = 0;
    $N_count      = 0;
    $scaf_name    = q{};
    @scaf_sizes   = ();
    @contig_sizes = ();
    %scaf_seqs    = ();

    my $INPUT_FILE;
    if ($infile eq '-') {
        # Special case: get the stdin handle
        $INPUT_FILE = *STDIN{IO};
    }
    else {
        # Standard case: open the file
        open $INPUT_FILE, '<', $infile or die "Can't open input file $infile. $!\n";
        $infile = basename $infile if $basename;
    }

    while (my $input = <$INPUT_FILE>) { 
        chomp $input;
        if ( $input =~ / \A > (\S+) /xms ) { 
            # Header line: start a new scaffold and bank the size of the previous one.
            $scaf_name = $1;
            if ( exists $scaf_seqs{$scaf_name} ) { 
                die "Redundant scaffold name: $scaf_name\n";
            }
            if ($nt_count) { 
                push @scaf_sizes, $nt_count;
            }
            $nt_count = 0;
        }
        elsif ( ( $input !~ / \A > /xms ) and ( $input =~ / \S /xms ) ) { 
            # Sequence line.  Any non-blank line qualifies, even if it contains
            # whitespace (e.g. the trailing "\r" of a CRLF file); everything
            # that is not a letter is then stripped before counting.
            $input =~ s/[^a-zA-Z]//g;
            $scaf_seqs{$scaf_name} .= $input;
            $nt_count += length($input);
            $N_count  += ( $input =~ tr/nN// );    # tr with empty replacement only counts.
        }
        # Lines starting with '>' but lacking a name right after it match
        # neither branch and are ignored.
    }
    # Finish off data at end of file.
    if ($nt_count) {
        push @scaf_sizes, $nt_count;
    }
    close $INPUT_FILE or die "Can't close filehandle to input file $infile. $!\n";

    # Without any residues the percentages below would divide by zero, so
    # fail with an explicit message instead.
    if (! @scaf_sizes) {
        die "No sequence residues found in input file $infile.\n";
    }

    foreach my $scaf_name2 (sort keys %scaf_seqs) { 
        $scaf_seqs{$scaf_name2} =~ tr/n/N/;
        my @contig_seqs = split /[N]+/, $scaf_seqs{$scaf_name2};
        foreach my $contig_seq (@contig_seqs) { 
            my $contig_len = length($contig_seq);
            # split() yields a leading empty field when the scaffold starts
            # with N; that is not a contig, so do not count it.
            next if $contig_len == 0;
            push @contig_sizes, $contig_len;
        }
    }

    # Sort in ascending numerical order, so that homebrewed N50 subroutine can work.
    @scaf_sizes   = sort { $a <=> $b } @scaf_sizes;
    @contig_sizes = sort { $a <=> $b } @contig_sizes;

    my $stat1 = Statistics::Descriptive::Full->new();    # Scaffold size statistics.
    $stat1->add_data(@scaf_sizes);

    my $stat2 = Statistics::Descriptive::Full->new();    # Contig size statistics.
    $stat2->add_data(@contig_sizes) if @contig_sizes;

    # There may be no contigs at all (every residue is N); any statistic that
    # comes back undefined in that case is reported as 0.
    my $or_zero = sub {
        my ($value) = @_;
        return defined $value ? $value : 0;
    };

    my $raw_scaf_sum      = $stat1->sum();    # Total nt of sequence in scaffolds, including N residues.
    my $raw_nonN_scaf_sum = $raw_scaf_sum - $N_count;
    my $perc_nonN         = ($raw_nonN_scaf_sum / $raw_scaf_sum) * 100;

    my $scaf_sum      = commify($raw_scaf_sum);
    my $nonN_scaf_sum = commify($raw_nonN_scaf_sum);
    my $N_count_text  = commify($N_count);
    $perc_nonN        = sprintf("%.1f", $perc_nonN);

    # Total nt of sequence in contigs, *not* including N residues joining them in scaffolds.
    my $raw_contig_sum = $or_zero->( $stat2->sum() );

    my $scaffolds  = commify( $stat1->count() );    # Total number of scaffolds.
    my $contigs    = commify( $stat2->count() );    # Total number of true contigs.

    my $scaf_min   = commify( $stat1->min() );
    my $contig_min = commify( $or_zero->( $stat2->min() ) );

    my $scaf_max   = commify( $stat1->max() );
    my $contig_max = commify( $or_zero->( $stat2->max() ) );

    my $scaf_n50 = get_n50( \@scaf_sizes, $raw_scaf_sum );
    $scaf_n50    = $or_zero->($scaf_n50);
    $scaf_n50    = commify( sprintf("%.1f", $scaf_n50) );

    my $contig_n50 = get_n50( \@contig_sizes, $raw_contig_sum );
    $contig_n50    = $or_zero->($contig_n50);
    $contig_n50    = commify( sprintf("%.1f", $contig_n50) );

    my $scaf_mean      = q{};
    my $contig_mean    = q{};
    my $scaf_std_dev   = q{};
    my $contig_std_dev = q{};
    my $scaf_median    = q{};
    my $contig_median  = q{};

    # The extra statistics are only computed when --extra was requested.
    if ($extra) { 
        $scaf_mean      = commify( sprintf( "%.1f", $stat1->mean() ) );
        $contig_mean    = commify( sprintf( "%.1f", $or_zero->( $stat2->mean() ) ) );

        $scaf_std_dev   = commify( sprintf( "%.1f", $stat1->standard_deviation() ) );
        $contig_std_dev = commify( sprintf( "%.1f", $or_zero->( $stat2->standard_deviation() ) ) );

        $scaf_median    = commify( sprintf( "%.1f", $stat1->median() ) );
        $contig_median  = commify( sprintf( "%.1f", $or_zero->( $stat2->median() ) ) );
    }

    print "\n";
    print "Sequence: $infile\n";

    print "\n";
    print "Total nt:           $scaf_sum\n";
    print "Scaffolds:          $scaffolds\n";
    print "Contigs:            $contigs\n";
    print "\n";

    print "ACGT nt:            $nonN_scaf_sum\n";
    print "N-res. nt:          $N_count_text\n";
    print "% non-N:            $perc_nonN\n";
    print "\n";

    print "Scaffold N50 nt:    $scaf_n50\n";
    print "Scaf. max. nt:      $scaf_max\n";
    print "Scaf. min. nt:      $scaf_min\n";
    print "\n";

    if ($extra) { 
        print "Scaf. mean nt:      $scaf_mean\n";
        print "Scaf. median nt:    $scaf_median\n";
        print "Scaf. s. dev. nt:   $scaf_std_dev\n";
        print "\n";
    }

    print "Contig N50:         $contig_n50\n";
    print "Contig max. nt:     $contig_max\n";
    print "Contig min. nt:     $contig_min\n";
    print "\n";

    if ($extra) {
        print "Contig mean nt:     $contig_mean\n";
        print "Cont. median nt:    $contig_median\n";
        print "Cont. s. dev. nt:   $contig_std_dev\n";
        print "\n";
    }
}

# commify($number): return $number as a string with a comma inserted between
# each group of three digits of the integer part; digits after a decimal
# point are left untouched.
# Source -- Perl Cookbook 2.16, p. 84.
sub commify {
    my ($number) = @_;

    # Work on the reversed string so that digit groups can be counted off
    # from the least-significant end.
    my $backwards = scalar reverse $number;

    # After every three digits that are followed by another digit, add a
    # comma -- unless a '.' still lies ahead in the reversed text, which
    # means those digits belong to the fractional part.
    $backwards =~ s/ (\d{3}) 
                     (?=\d) 
                     (?!\d*\.)
                   /$1,/xmsg;

    my $with_commas = scalar reverse $backwards;
    return $with_commas;
}

# get_n50($sizes_ref, $sum): walk the sizes in the order given, keeping a
# running total, and return the first size at which that total exceeds half
# of $sum.
#   $sizes_ref -- reference to an array of sizes; assumed to be sorted in
#                 ascending numerical order.
#   $sum       -- precomputed total of all the sizes.
# Returns nothing (undef in scalar context) if the running total never
# exceeds half of $sum, e.g. for an empty list.
sub get_n50 {
    my ($sizes_ref, $sum) = @_;

    my $half_of_total = $sum / 2;
    my $running_total = 0;

    for my $size ( @{$sizes_ref} ) {
        $running_total += $size;
        return $size if $running_total > $half_of_total;
    }

    return;
}

