##################################
#                                #
# Last modified 2021/12/31       # 
#                                #
# Georgi Marinov                 #
#                                # 
##################################

import sys
import string
import math
import numpy
import gzip
from sets import Set

def IsLeap(year):
    """Return True if `year` is a leap year under the Gregorian rule, else False.

    A year is leap when it is divisible by 4, except for century years,
    which are leap only when they are also divisible by 400.
    """
    # Century years (multiples of 100) are decided solely by the 400 rule.
    if year % 100 == 0:
        return year % 400 == 0

    # Every other year is leap exactly when it is a multiple of 4.
    return year % 4 == 0

def GetDate(date,FY):
    """Convert a 'month/day/year' string into a running day number.

    Day 1 is January 1 of the reference year FY; numbering continues
    across year boundaries, with leap years taken into account.

    Parameters:
        date -- string whose first three '/'-separated fields are the
                month, day and year as integers (e.g. '2/15/2021')
        FY   -- integer reference year in which the numbering starts

    Returns:
        The integer day number.  Dates that fall before January 1 of FY
        yield values <= 0, so differences between two day numbers are
        always true intervals in days.

    Raises:
        IndexError -- if `date` has fewer than three '/'-separated fields
        ValueError -- if a field is not an integer or the fields do not
                      form a valid calendar date
    """
    # Imported locally so this function is self-contained.
    import datetime

    parts = date.split('/')
    month = int(parts[0])
    day = int(parts[1])
    year = int(parts[2])

    # The previous hand-rolled loop added the length of the *current* month
    # `month` times instead of summing the lengths of the months before it,
    # so day numbers (and intervals spanning a month boundary) were wrong.
    # Calendar arithmetic from the standard library avoids that entirely.
    elapsed = datetime.date(year, month, day) - datetime.date(FY, 1, 1)

    # +1 so that January 1 of FY is day 1 rather than day 0.
    return elapsed.days + 1

def run():
    """Report per-patient reinfection intervals from a tab-separated table.

    Command-line arguments (all six are required, in this order):
        data.tsv        input table; '-' reads standard input and a name
                        ending in '.gz' is opened with gzip
        ID_fieldID      0-based index of the column holding the integer ID
        dates_field_ID  0-based index of the column holding the date
        first_year      reference year passed to GetDate
        minInterval     minimum difference between two consecutive day
                        numbers of the same ID for the pair to be reported
        outfilename     path of the tab-separated output file

    Input lines starting with '#' and blank lines are skipped.  Each date
    is converted to a day number with GetDate and duplicate day numbers
    for the same ID are collapsed.  For every ID (in ascending order) the
    sorted day numbers are printed to standard output, and each pair of
    consecutive day numbers at least minInterval apart is written to the
    output file as: ID, earlier day, later day, interval, later day // 7.

    Exits with status 1 after printing a usage message when too few
    arguments are supplied.
    """
    # Six positional arguments are read below (sys.argv[1] .. sys.argv[6]),
    # so at least seven entries are needed including the script name.
    if len(sys.argv) < 7:
        print('usage: python %s data.tsv ID_fieldID dates_field_ID first_year minInterval outfilename' % sys.argv[0])
        print('\tassumed date format: 2/15/2021')
        sys.exit(1)

    datafilename = sys.argv[1]
    IDFieldID = int(sys.argv[2])
    dateFieldID = int(sys.argv[3])
    first_year = int(sys.argv[4])
    minInterval = int(sys.argv[5])
    outfilename = sys.argv[6]

    outfile = open(outfilename, 'w')
    try:
        if datafilename == '-':
            lineslist = sys.stdin
        elif datafilename.endswith('.gz'):
            lineslist = gzip.open(datafilename)
        else:
            lineslist = open(datafilename)

        # Map each ID to the set of distinct day numbers seen for it; a set
        # de-duplicates on insert instead of rebuilding a list on every line.
        PatientDict = {}
        try:
            for line in lineslist:
                if line.startswith('#'):
                    continue
                stripped = line.strip()
                if not stripped:
                    # Blank lines carry no record.
                    continue
                fields = stripped.split('\t')
                ID = int(fields[IDFieldID])
                date = GetDate(fields[dateFieldID], first_year)
                PatientDict.setdefault(ID, set()).add(date)
        finally:
            # Standard input is not ours to close.
            if lineslist is not sys.stdin:
                lineslist.close()

        outline = '#ID\tprevious_infection\treinfection\treinfection_interval\treinfection_week'
        outfile.write(outline + '\n')

        for ID in sorted(PatientDict):
            days = sorted(PatientDict[ID])
            # Progress/diagnostic listing of every ID with its day numbers.
            print('%s %s' % (ID, days))
            for previous, current in zip(days, days[1:]):
                reinfection_interval = current - previous
                if reinfection_interval < minInterval:
                    continue
                # Floor division keeps the week an integer on Python 2 and 3.
                columns = (ID, previous, current, reinfection_interval, current // 7)
                outfile.write('\t'.join(str(value) for value in columns) + '\n')
    finally:
        outfile.close()
        
# Script entry point: run() is called unconditionally at module top level
# (there is no __main__ guard), so it also executes if this file is imported.
run()

