#!/usr/bin/env python3
# -*- coding: utf-8 -*-
#
#  =================================================================
#           #     #                 #     #
#           ##    #   ####   #####  ##    #  ######   #####
#           # #   #  #    #  #    # # #   #  #          #
#           #  #  #  #    #  #    # #  #  #  #####      #
#           #   # #  #    #  #####  #   # #  #          #
#           #    ##  #    #  #   #  #    ##  #          #
#           #     #   ####   #    # #     #  ######     #
#
#        ---   The NorNet Testbed for Multi-Homed Systems  ---
#                        https://www.nntb.no
#  =================================================================
#
#  High-Performance Connectivity Tracer (HiPerConTracer)
#  Copyright (C) 2015-2022 by Thomas Dreibholz
#
#  This program is free software: you can redistribute it and/or modify
#  it under the terms of the GNU General Public License as published by
#  the Free Software Foundation, either version 3 of the License, or
#  (at your option) any later version.
#
#  This program is distributed in the hope that it will be useful,
#  but WITHOUT ANY WARRANTY; without even the implied warranty of
#  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
#  GNU General Public License for more details.
#
#  You should have received a copy of the GNU General Public License
#  along with this program.  If not, see <http://www.gnu.org/licenses/>.
#
#  Contact: dreibh@simula.no

import os
import sys
import io
import datetime
import bz2
import shutil
import configparser
import operator
import psycopg2
import ssl
from collections import OrderedDict
from pymongo     import MongoClient
from ipaddress   import ip_address


# ###### Print log message ##################################################
def log(logstring):
   """Print a timestamped message to standard output, wrapped in the ANSI
   escape sequences for green text.

   logstring must be a str; the local time is formatted with second
   resolution (YYYY-MM-DDTHH:MM:SS).
   """
   timeStampString = datetime.datetime.now().strftime('%Y-%m-%dT%H:%M:%S')
   # '\x1b[32m' switches to green, '\x1b[0m' resets the terminal attributes.
   print(''.join(('\x1b[32m', timeStampString, ': ', logstring, '\x1b[0m')))


# ###### Print warning message ##############################################
def warning(logstring):
   """Write a timestamped warning to standard error, wrapped in the ANSI
   escape sequences for red text.

   logstring must be a str; the message is prefixed with "WARNING: " and
   terminated with a newline.
   """
   timeStampString = datetime.datetime.now().strftime('%Y-%m-%dT%H:%M:%S')
   # '\x1b[31m' switches to red, '\x1b[0m' resets the terminal attributes.
   message = ''.join(('\x1b[31m', timeStampString, ': WARNING: ', logstring, '\x1b[0m\n'))
   sys.stderr.write(message)


# ###### Abort with error ###################################################
def error(logstring):
   """Write a timestamped error message to standard error and terminate the
   program with exit status 1.

   logstring must be a str. The time stamp is the ISO 8601 representation of
   the current local time. This function does not return: sys.exit() raises
   SystemExit.
   """
   timeStampString = datetime.datetime.now().isoformat()
   message = ''.join((timeStampString, ' ===== ERROR: ', logstring, ' =====\n'))
   sys.stderr.write(message)
   sys.exit(1)


# ###### Read input and prepare output ######################################

# Input types:
IT_NONE       = 0
IT_PING       = 1
IT_TRACEROUTE = 2

# Output types:
OT_POSTGRES   = 1
OT_MONGODB    = 2


# ###### Parse hexadecimal number in canonical notation #####################
def _parseCanonicalHex(text, description, lineNumber):
   """Parse text as hexadecimal number and verify its canonical notation.

   Canonical means: exactly what hex() would print without the "0x" prefix,
   i.e. lower-case digits, no leading zeros, no sign, no prefix. This is an
   explicit check (instead of an assert) so that it is still performed when
   Python runs with -O. It also guarantees that the text can be embedded
   verbatim into an SQL statement.

   Raises ValueError if text is not a canonical hexadecimal number.
   """
   value = int(text, 16)
   if ('0x' + text) != hex(value):
      raise ValueError('Non-canonical hexadecimal ' + description + ' "' + text +
                       '" in line ' + str(lineNumber))
   return value


# ###### Convert microseconds since 1970-01-01 to datetime ##################
def _microsecondsToDateTime(timeStamp):
   """Return the naive datetime that is timeStamp microseconds after
   1970-01-01 00:00:00."""
   return datetime.datetime(1970, 1, 1, 0, 0, 0, 0) + \
             datetime.timedelta(microseconds = timeStamp)


# ###### Convert results file into database input ###########################
def processInput(inputFile, outputType):
   """Read a results file and convert it into input for the database.

   inputFile is an iterable of text lines. Each line is split at single
   spaces; the first field selects the record type:
   - '#P': Ping result (at least 7 fields)
   - '#T': Traceroute header (at least 9 fields)
   - TAB:  hop of the preceding Traceroute header (at least 5 fields)
   A file must contain either Ping or Traceroute records, not both.

   outputType selects the result format:
   - OT_POSTGRES: a single multi-row SQL INSERT statement (str).
   - OT_MONGODB:  a list [ inputType, documents ], where documents is a list
                  of OrderedDict objects.
   In both cases, the entries are sorted by their label (source, destination,
   time stamp, and for Traceroute also round/hop).

   Returns None if the input does not produce any entry.
   Raises an Exception (or a subclass like ValueError) on malformed input.
   """
   inputType = IT_NONE
   output    = {}
   hopCheck  = {}
   for lineNumber, inputLine in enumerate(inputFile, 1):
      tuples = inputLine.rstrip().split(' ')

      # ====== Ping =========================================================
      if tuples[0] == '#P':
         if len(tuples) < 7:
            raise Exception('Bad input for Ping in line ' + str(lineNumber))

         # ------ Handle input ----------------------------------------------
         if inputType == IT_NONE:
            inputType = IT_PING
         elif inputType != IT_PING:
            raise Exception('Multiple input types in the same file?!')

         sourceIP      = ip_address(tuples[1])
         destinationIP = ip_address(tuples[2])
         timeStamp     = _parseCanonicalHex(tuples[3], 'time stamp', lineNumber)
         checksum      = _parseCanonicalHex(tuples[4], 'checksum', lineNumber)
         status        = int(tuples[5])
         rtt           = int(tuples[6])
         # TrafficClass was added in HiPerConTracer 1.4.0!
         trafficClass  = int(tuples[7], 16) if len(tuples) >= 8 else 0
         # PacketSize was added in HiPerConTracer 1.6.0!
         packetSize    = int(tuples[8]) if len(tuples) >= 9 else 0

         # ------ Generate output -------------------------------------------
         label = str(sourceIP) + '-' + str(destinationIP) + '-' + str(timeStamp)
         if outputType == OT_POSTGRES:
            timeStampStr = _microsecondsToDateTime(timeStamp).strftime("%Y-%m-%dT%H:%M:%S.%f")
            # The order of the values must match the column list of the
            # INSERT statement generated at the end of this function!
            output[label] = '(' + ','.join([
               '\'' + timeStampStr       + '\'',
               '\'' + str(sourceIP)      + '\'',
               '\'' + str(destinationIP) + '\'',
               str(packetSize),
               str(trafficClass),
               str(status),
               str(rtt) ]) + ')'

         elif outputType == OT_MONGODB:
            output[label] = OrderedDict([
                               ( 'source',      sourceIP.packed      ),
                               ( 'destination', destinationIP.packed ),
                               ( 'pktsize',     int(packetSize)      ),
                               ( 'tc',          int(trafficClass)    ),
                               ( 'timestamp',   int(timeStamp)       ),
                               ( 'checksum',    int(checksum)        ),
                               ( 'status',      int(status)          ),
                               ( 'rtt',         int(rtt)             ) ])

      # ====== Traceroute header ============================================
      elif tuples[0] == '#T':
         if len(tuples) < 9:
            raise Exception('Bad input for Traceroute in line ' + str(lineNumber))

         # ------ Handle input ----------------------------------------------
         if inputType == IT_NONE:
            inputType = IT_TRACEROUTE
         elif inputType != IT_TRACEROUTE:
            raise Exception('Multiple input types in the same file?!')

         # NOTE: These values stay valid for all following hop lines, until
         #       the next Traceroute header replaces them.
         sourceIP      = ip_address(tuples[1])
         destinationIP = ip_address(tuples[2])
         timeStamp     = _parseCanonicalHex(tuples[3], 'time stamp', lineNumber)
         roundNumber   = int(tuples[4])
         checksum      = _parseCanonicalHex(tuples[5], 'checksum', lineNumber)
         totalHops     = int(tuples[6])
         statusFlags   = _parseCanonicalHex(tuples[7], 'status flags', lineNumber)
         pathHashStr   = tuples[8]
         pathHash      = _parseCanonicalHex(pathHashStr, 'path hash', lineNumber)
         # TrafficClass was added in HiPerConTracer 1.4.0!
         trafficClass  = int(tuples[9], 16) if len(tuples) >= 10 else 0
         # PacketSize was added in HiPerConTracer 1.6.0!
         packetSize    = int(tuples[10]) if len(tuples) >= 11 else 0

         if outputType == OT_POSTGRES:
            # For PostgreSQL, the rows are generated per hop. The header only
            # provides the time stamp string used by these rows.
            timeStampStr = _microsecondsToDateTime(timeStamp).strftime("%Y%m%dT%H%M%S.%f")

         elif outputType == OT_MONGODB:
            label = str(sourceIP) + '-' + str(destinationIP) + '-' + str(timeStamp) + '-' + str(roundNumber).zfill(3)
            # MongoDB only supports signed integers:
            if pathHash > 0x7FFFFFFFFFFFFFFF:
               pathHash -= 0x10000000000000000
            hopCheck[label] = 0
            output[label] = OrderedDict([
                               ( 'source',      sourceIP.packed      ),
                               ( 'destination', destinationIP.packed ),
                               ( 'pktsize',     int(packetSize)      ),
                               ( 'tc',          int(trafficClass)    ),
                               ( 'timestamp',   int(timeStamp)       ),
                               ( 'round',       int(roundNumber)     ),
                               ( 'checksum',    int(checksum)        ),
                               ( 'totalHops',   int(totalHops)       ),
                               ( 'statusFlags', int(statusFlags)     ),
                               ( 'pathHash',    int(pathHash)        ),
                               ( 'hops',        []                   ) ])

      # ====== Traceroute hop ===============================================
      elif ((tuples[0] == '\t') and (inputType == IT_TRACEROUTE)):
         # The hop IP is the 5th field (index 4), so 5 fields are required!
         if len(tuples) < 5:
            raise Exception('Bad input for Traceroute in line ' + str(lineNumber))

         # ------ Handle input ----------------------------------------------
         hopNumber = int(tuples[1])
         status    = _parseCanonicalHex(tuples[2], 'status', lineNumber)
         rtt       = int(tuples[3])
         hopIP     = ip_address(tuples[4])

         if hopNumber > totalHops:
            raise ValueError('Hop number ' + str(hopNumber) + ' exceeds total hops ' +
                             str(totalHops) + ' in line ' + str(lineNumber))

         # ------ Generate output -------------------------------------------
         if outputType == OT_POSTGRES:
            label = str(sourceIP) + '-' + str(destinationIP) + '-' + str(timeStamp) + '-' + str(roundNumber).zfill(3) + str(hopNumber).zfill(3)
            # The order of the values must match the column list of the
            # INSERT statement generated at the end of this function!
            output[label] = '(' + ','.join([
               '\'' + timeStampStr       + '\'',
               '\'' + str(sourceIP)      + '\'',
               '\'' + str(destinationIP) + '\'',
               str(packetSize),
               str(trafficClass),
               str(hopNumber),
               str(totalHops),
               str(status | statusFlags),
               str(rtt),
               '\'' + str(hopIP) + '\'',
               'CAST(X\'' + pathHashStr + '\' AS BIGINT)',
               str(roundNumber) ]) + ')'

         elif outputType == OT_MONGODB:
            label = str(sourceIP) + '-' + str(destinationIP) + '-' + str(timeStamp) + '-' + str(roundNumber).zfill(3)
            # Make sure that all hops are in order!
            if hopCheck[label] + 1 != hopNumber:
               raise ValueError('Hop ' + str(hopNumber) + ' is out of order in line ' + str(lineNumber))
            hopCheck[label] = hopNumber

            output[label]['hops'].append(OrderedDict([
               ( 'hop',    hopIP.packed ),
               ( 'status', int(status)  ),
               ( 'rtt',    int(rtt)     ) ]))

      # ====== Error ========================================================
      else:
         # This also covers blank lines, and hop lines without a header.
         raise Exception('Unexpected input in line ' + str(lineNumber))


   # ====== Sort result =====================================================
   if not output:
      # There are no results => nothing to do!
      return None
   resultsList = sorted(output.items(), key=operator.itemgetter(0))

   # ====== Generate output string ==========================================
   if outputType == OT_POSTGRES:
      # The column lists must have one column for each generated value,
      # including the packet size.
      # NOTE(review): the column name PktSize is assumed from the
      # HiPerConTracer database schema -- confirm against the deployed tables!
      outputString = ""
      if inputType == IT_PING:
         outputString = 'INSERT INTO Ping (TimeStamp,FromIP,ToIP,PktSize,TC,Status,RTT) VALUES '
      elif inputType == IT_TRACEROUTE:
         outputString = 'INSERT INTO Traceroute (TimeStamp,FromIP,ToIP,PktSize,TC,HopNumber,TotalHops,Status,RTT,HopIP,PathHash,Round) VALUES '
      return outputString + '\n' + ',\n'.join(row for _, row in resultsList) + ';'

   elif outputType == OT_MONGODB:
      return [ inputType, [ document for _, document in resultsList ] ]



# ###### Main program #######################################################

# TEST:
#for absTransactionFile in [ 'x.bz2', 'y.bz2' ] :
   #inputFile = bz2.open(absTransactionFile, 'rt')
   #transactionContent = processInput(inputFile, OT_MONGODB)
   #print(transactionContent)
#sys.exit(1)


# ====== Handle arguments ===================================================
# The database configuration file is the only mandatory argument:
if len(sys.argv) < 2:
   error('Usage: ' + sys.argv[0] + ' database_configuration [-verbose]')
configFileName = sys.argv[1]

# Defaults, which may be overridden by the database configuration file:
transactionsPath = None
badFilePath      = None
outputType       = OT_POSTGRES
dbServer         = 'localhost'
dbPort           = 5432
dbUser           = 'importer'
dbPassword       = None
dbCAFile         = None
dbName           = 'pingtraceroutedb'

# All further arguments are options. Only "-verbose" is known; anything else
# terminates the program via error().
verboseMode = False
for argument in sys.argv[2:]:
   if argument != '-verbose':
      error('Bad argument: ' + argument)
   verboseMode = True


# ====== Get configuration ==================================================
parsedConfigFile = configparser.RawConfigParser()
parsedConfigFile.optionxform = str   # Make it case-sensitive!
try:
   # The configuration file is read as a list of options without section
   # header. Since configparser requires a section, a dummy [root] section is
   # prepended before parsing.
   # NOTE: read_string() is used here, since the deprecated readfp() is no
   #       longer available in current Python versions.
   with open(configFileName, 'r') as configFile:
      parsedConfigFile.read_string('[root]\n' + configFile.read())
except Exception as e:
   # error() terminates the program.
   error('Unable to read database configuration file ' + configFileName + ': ' + str(e))

# ------ Apply the parameters from the configuration file -------------------
for parameterName in parsedConfigFile.options('root'):
   parameterValue = parsedConfigFile.get('root', parameterName)
   if parameterName == 'transactions_path':
      transactionsPath = parameterValue
   elif parameterName == 'bad_file_path':
      badFilePath = parameterValue
   elif parameterName == 'dbbackend':
      if parameterValue == 'PostgreSQL':
         outputType = OT_POSTGRES
      elif parameterValue == 'MongoDB':
         outputType = OT_MONGODB
      else:
         error('Unknown DB backend ' + parameterValue + ' in ' + sys.argv[1] + '!')
   elif parameterName == 'dbserver':
      dbServer = parameterValue
   elif parameterName == 'dbport':
      # Kept as string here; it is converted when connecting to the database.
      dbPort = parameterValue
   elif parameterName == 'dbuser':
      dbUser = parameterValue
   elif parameterName == 'dbpassword':
      dbPassword = parameterValue
   elif parameterName == 'dbcafile':
      dbCAFile = parameterValue
   elif parameterName == 'database':
      dbName = parameterValue
   else:
      error('Unknown parameter ' + parameterName + ' in ' + sys.argv[1] + '!')

# ------ Check the directories ----------------------------------------------
# Both paths have no default. Without these checks, a missing setting would
# lead to a TypeError traceback instead of a useful error message.
if transactionsPath is None:
   error('Parameter transactions_path is missing in ' + sys.argv[1] + '!')
if badFilePath is None:
   error('Parameter bad_file_path is missing in ' + sys.argv[1] + '!')

if not os.path.exists(transactionsPath):
   error('Invalid transactions path ' + transactionsPath + '!')
try:
   os.makedirs(badFilePath, exist_ok=True)
except Exception as e:
   error('Unable to create bad file directory ' + badFilePath + ': ' + str(e))


# ====== Connect to the database ============================================
# The dbcafile setting selects the TLS certificate handling:
# - "IGNORE":       no certificate check (a warning is printed)
# - "None":         certificate check with the default CA settings
# - anything else:  certificate check, the value is handed over as CA file
if outputType == OT_POSTGRES:
   try:
      if dbCAFile == "IGNORE":   # ------ Ignore TLS certificate ------------
         warning('TLS certificate check is turned off!')
         tlsArguments = { 'sslmode': 'require' }
      elif dbCAFile == "None":   # ------ Use default CA settings -----------
         tlsArguments = { 'sslmode': 'verify-ca' }
      else:   # ------ Use given CA -----------------------------------------
         tlsArguments = { 'sslmode': 'verify-ca', 'sslrootcert': dbCAFile }

      dbConnection = psycopg2.connect(host     = str(dbServer),
                                      port     = str(dbPort),
                                      user     = str(dbUser),
                                      password = str(dbPassword),
                                      dbname   = str(dbName),
                                      **tlsArguments)
      # Transactions are committed or rolled back explicitly by the importer:
      dbConnection.autocommit = False
   except Exception as e:
      log('Unable to connect to the PostgreSQL database: ' + str(e))
      sys.exit(1)
   dbCursor = dbConnection.cursor()

elif outputType == OT_MONGODB:
   try:
      if dbCAFile == "IGNORE":   # ------ Ignore TLS certificate ------------
         warning('TLS certificate check is turned off!')
         tlsArguments = { 'ssl_cert_reqs': ssl.CERT_NONE }
      elif dbCAFile == "None":   # ------ Use default CA settings -----------
         tlsArguments = { 'ssl_cert_reqs': ssl.CERT_REQUIRED }
      else:   # ------ Use given CA, requires PyMongo >= 3.4! ---------------
         tlsArguments = { 'ssl_cert_reqs': ssl.CERT_REQUIRED,
                          'ssl_ca_certs':  dbCAFile }

      dbConnection = MongoClient(host = str(dbServer),
                                 port = int(dbPort),
                                 ssl  = True,
                                 **tlsArguments)
      db = dbConnection[str(dbName)]
      db.authenticate(str(dbUser), str(dbPassword), mechanism='SCRAM-SHA-1')
   except Exception as e:
      log('Unable to connect to the MongoDB database: ' + str(e))
      sys.exit(1)


# ====== Import transactions ================================================
fileNumber       = 0
goodTransactions = 0
badTransactions  = 0
firstTransaction = True

# The transaction files are processed in sorted order of their names.
transactionFileList = sorted(
   name for name in os.listdir(transactionsPath)
   if os.path.isfile(os.path.join(transactionsPath, name)))
for transactionFile in transactionFileList:
   absTransactionFile = os.path.join(transactionsPath, transactionFile)
   if not os.path.isfile(absTransactionFile):
      # The file is gone since the directory has been listed.
      continue

   if firstTransaction:
      firstTransaction = False
      log('Starting import of new transactions ...')

   # ------ Read transactions from file -------------------------------------
   fileNumber = fileNumber + 1
   if verboseMode:
      log('Importing ' + absTransactionFile +
          ' (' + str(fileNumber) + ' of ' + str(len(transactionFileList)) + ') ...')

   try:
      if os.stat(absTransactionFile).st_size == 0:
         log('Transaction ' + transactionFile + ' has size 0 -> nothing to do')
         os.remove(absTransactionFile)
         continue
   except OSError:
      # Best effort only: a real problem with the file is going to be
      # reported by the read attempt below.
      pass

   transactionContent = None
   try:
      # The "with" statement ensures that the file gets closed, also when
      # processInput() raises an exception.
      with bz2.open(absTransactionFile, 'rt') as inputFile:
         transactionContent = processInput(inputFile, outputType)
   except Exception as e:
      log('Transaction ' + transactionFile + ' cannot be read: ' + str(e) + ' -> moving it to bad file directory')
      try:
         shutil.move(absTransactionFile, badFilePath)
      except Exception as moveException:
         error('Unable to move bad transaction ' + absTransactionFile + ' to ' + badFilePath + ': ' + str(moveException))
      # An unreadable file is a bad transaction. It must neither be reported
      # as empty, nor be missing in the final statistics.
      badTransactions = badTransactions + 1
      continue

   if transactionContent is None:
      log('Transaction ' + transactionFile + ' is empty -> nothing to do')

   else:
      # ------ Commit transactions ------------------------------------------
      try:
         # print(transactionContent)

         if outputType == OT_POSTGRES:
            dbCursor.execute(transactionContent)
            dbConnection.commit()

         elif outputType == OT_MONGODB:
            # NOTE(review): insert() is the legacy PyMongo call; presumably
            # insert_many() is needed with newer PyMongo versions -- confirm
            # against the deployed PyMongo version!
            if transactionContent[0] == IT_PING:
               db['ping'].insert(transactionContent[1])
            elif transactionContent[0] == IT_TRACEROUTE:
               db['traceroute'].insert(transactionContent[1])
            else:
               raise Exception('Content type ' + str(transactionContent[0]) + ' is unknown!')

         goodTransactions = goodTransactions + 1

      # ------ Handle exceptions -> rollback --------------------------------
      except Exception as e:
         # ------ Connection is broken --------------------------------------
         if outputType == OT_POSTGRES:
            if ( (dbConnection.closed) or
                 ("SSL SYSCALL error" in str(e)) ):
               # Need to check for "SSL SYSCALL errors", since psycopg2
               # does not detect connection breaks properly.
               # See https://bitbucket.org/zzzeek/sqlalchemy/issues/3021/ssl-eof-not-detected-as-disconnect-in
               # The transaction file is kept, i.e. it is still there for the
               # next run of the importer.
               log('The database connection seems to be closed. Aborting import!')
               break

         # ------ Other problem (e.g. bad SQL statements, etc.) -------------
         log('Transaction ' + transactionFile + ' cannot be committed: ' + str(e) + ' -> moving it to bad file directory')
         if outputType == OT_POSTGRES:
            dbConnection.rollback()
         try:
            shutil.move(absTransactionFile, badFilePath)
         except Exception as moveException:
            error('Unable to move bad transaction ' + absTransactionFile + ' to ' + badFilePath + ': ' + str(moveException))
         badTransactions = badTransactions + 1

   # ------ Remove completed transaction ------------------------------------
   # A bad transaction has already been moved away, i.e. its file does not
   # exist here any more.
   try:
      if os.path.exists(absTransactionFile):
         os.remove(absTransactionFile)
   except Exception as e:
      error('Unable to remove completed transaction ' + absTransactionFile + ': ' + str(e))


# ====== All done! ==========================================================
log(str(goodTransactions) + ' transactions committed, ' + str(badTransactions) + ' were bad.')
