#! /usr/bin/env python
# -*- coding: utf-8 -*-

import os
import sys
import random
import ConfigParser

import utils.terminal as terminal
import utils.fastalib as u
# Short alias for the terminal module's formatter; used below to format the
# counts shown in progress and summary messages.
pp = terminal.pretty_print

# Characters that simulate_errors() draws a replacement from when it mutates
# a position of a read.
bases = ['A', 'T', 'C', 'G', 'N']


class Configuration:
    """Run settings extracted from a parsed configuration object.

    `config` must provide `sections()` and `get(section, option)` in the
    way ConfigParser objects do. The `general` section supplies
    `output_file`, `short_read_length` and `error_rate`; every other
    section name is treated as the path of a FASTA file and must define
    `coverage`.

    Attributes:
        output_file: value of general/output_file.
        short_read_length: general/short_read_length converted to int.
        error_rate: general/error_rate converted to float.
        fasta_files: aliases of the FASTA sections, in the order returned
            by `config.sections()`.
        fasta_files_dict: alias -> {'path', 'alias', 'coverage'}.
    """

    def __init__(self, config):
        self.sanity_check(config)

        # Read everything from the object that was passed in, not from a
        # module-level global, so the class works no matter how the caller
        # names (or scopes) its parser instance.
        self.output_file = config.get('general', 'output_file')
        self.short_read_length = int(config.get('general', 'short_read_length'))
        self.error_rate = float(config.get('general', 'error_rate'))

        self.fasta_files = []
        self.fasta_files_dict = {}

        for section in [s for s in config.sections() if s != 'general']:
            # Alias is the file name without its last extension. Taking the
            # basename first keeps dots in directory names from truncating
            # the path, and extension-less files keep their full name.
            alias = os.path.splitext(os.path.basename(section))[0]
            self.fasta_files.append(alias)
            self.fasta_files_dict[alias] = {'path': section,
                                            'alias': alias,
                                            'coverage': int(config.get(section, 'coverage'))}


    def sanity_check(self, config):
        """Placeholder for configuration validation; currently a no-op."""
        # don't have any interest in implementing this right now.
        pass


def simulate_errors(error_rate, sequence, alphabet=None):
    """Return a copy of `sequence` with random substitutions, and their count.

    Each position is replaced independently with probability `error_rate`
    by a character drawn uniformly from `alphabet`. The replacement may
    coincide with the original character; it is still counted as an error.

    Args:
        error_rate: per-position substitution probability. Values <= 0
            disable error simulation; values >= 1 replace every position.
        sequence: the string to mutate.
        alphabet: sequence of replacement characters. Defaults to the
            module-level `bases` list.

    Returns:
        A tuple `(sequence_with_errors, num_errors)`.
    """
    if error_rate <= 0:
        return sequence, 0

    if alphabet is None:
        alphabet = bases

    # random.random() is uniform on [0.0, 1.0), so the comparison below is
    # true with probability exactly `error_rate` (no 1/1000 granularity and
    # no off-by-one from an inclusive integer range).
    chars = []
    num_errors = 0
    for original_char in sequence:
        if random.random() < error_rate:
            chars.append(random.choice(alphabet))
            num_errors += 1
        else:
            chars.append(original_char)

    return ''.join(chars), num_errors
 

def main(config):
    """Write simulated short reads for every FASTA file listed in `config`.

    For each entry of each input file, about `length * coverage /
    short_read_length` reads are sampled at uniformly random start
    positions, passed through simulate_errors(), and written to
    `config.output_file`. Entries shorter than the read length cannot
    yield a full-length read and are skipped.

    Args:
        config: object exposing `output_file`, `short_read_length`,
            `error_rate`, `fasta_files` and `fasta_files_dict` (as built by
            Configuration).
    """
    run = terminal.Run(width = 15)
    progress = terminal.Progress()

    output = u.FastaOutput(config.output_file)
    read_length = config.short_read_length
    num_files = len(config.fasta_files)

    try:
        for file_index, alias in enumerate(config.fasta_files):
            f = config.fasta_files_dict[alias]
            coverage = f['coverage']

            progress.new('Working on file %d of %d (%s) with expected coverage of %d' % (file_index + 1, num_files, f['alias'], coverage))

            fasta = u.SequenceSource(f['path'])
            total_num_errors = 0
            total_num_reads = 0
            while fasta.next():
                seq_length = len(fasta.seq)

                # A full-length read cannot be cut from a shorter entry
                # (randint(0, negative) would raise), so skip it.
                if seq_length < read_length:
                    continue

                # Multiply before the floor division so that the fractional
                # part of seq_length / read_length is not thrown away before
                # it is scaled by the coverage.
                num_reads_needed = seq_length * coverage // read_length
                total_num_reads += num_reads_needed

                last_start = seq_length - read_length
                for read_index in range(0, num_reads_needed):
                    if (read_index + 1) % 100 == 0:
                        progress.update('Entry %s :: %s nts :: reads %s of %s :: num errors: %s ...'\
                                                        % (pp(fasta.pos + 1), pp(seq_length),
                                                           pp(read_index + 1), pp(num_reads_needed),
                                                           pp(total_num_errors)))
                    start_pos = random.randint(0, last_start)
                    stop_pos = start_pos + read_length
                    short_read, num_errors = simulate_errors(config.error_rate, fasta.seq[start_pos:stop_pos])
                    total_num_errors += num_errors

                    output.write_id('%s_%d|source:%s|start:%d|stop:%d' % (f['alias'], read_index, fasta.id, start_pos, stop_pos))
                    output.write_seq(short_read)

            progress.end()

            # Guard the average against files that produced no reads at all
            # (empty input, or only entries shorter than the read length).
            total_bases = total_num_reads * read_length
            if total_bases:
                average_error_rate = total_num_errors * 1.0 / total_bases
            else:
                average_error_rate = 0.0

            run.info(f['alias'], '%s reads w/ %s errors (average rate of %.4f) generated for %sX average coverage.'\
                                        % (pp(total_num_reads),
                                           pp(total_num_errors),
                                           average_error_rate,
                                           pp(coverage),
                                           ))
    finally:
        # Always release the output handle, even if reading an input fails.
        output.close()

    run.info('Fasta output', config.output_file)


# Command-line entry point: parse the configuration file given as the only
# positional argument and run the simulation.
if __name__ == '__main__':
    import argparse

    parser = argparse.ArgumentParser(description='Generate short reads from contigs')
    parser.add_argument('configuration', metavar = 'CONFIG_FILE',
                                        help = 'Configuration file')

    args = parser.parse_args()
    user_config = ConfigParser.ConfigParser()

    # read() silently skips files it cannot open and returns the list of
    # files it did parse; an empty result means there is nothing to run on,
    # so report it here instead of failing later with a missing-section error.
    if not user_config.read(args.configuration):
        parser.error("Configuration file '%s' could not be read." % args.configuration)

    config = Configuration(user_config)
    sys.exit(main(config))
