#!/usr/bin/env python
#
# This file is part of NNTSC.
#
# Copyright (C) 2013-2017 The University of Waikato, Hamilton, New Zealand.
#
# Authors: Shane Alcock
#          Brendon Jones
#
# All rights reserved.
#
# This code has been developed by the WAND Network Research Group at the
# University of Waikato. For further information please see
# http://www.wand.net.nz/
#
# NNTSC is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License version 2 as
# published by the Free Software Foundation.
#
# NNTSC is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
# General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with NNTSC; if not, write to the Free Software Foundation, Inc.
# 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.
#
# Please report any bugs, questions or comments to contact@wand.net.nz
#

import sys
import argparse
from multiprocessing import *

from libnntsc.database import DBInsert
from libnntsc.configurator import *
from libnntsc.dberrorcodes import *
from libnntsc.importer import import_parsers
from libnntsc.exporter import NNTSCExporter
from libnntscclient.pidfile import PidFile
import libnntscclient.logger as logger
import daemon


class DataCollector(object):
    """Top-level NNTSC collector.

    Loads the NNTSC configuration, creates one child process per enabled
    collection module and runs the exporter that listens for client
    connections on the configured port.
    """

    def __init__(self, listen_port, config, backgrounded, exportonly,
            querytimeout):
        """Load the module list and create the exporter.

        Parameters:
          listen_port -- port handed to NNTSCExporter.
          config -- path of the NNTSC configuration file.
          backgrounded -- if true, log to /tmp/nntsc.log via the NNTSC
                          logger.
          exportonly -- if true, no collection modules are started by
                        run().
          querytimeout -- passed through to the exporter's configure().

        Exits the program with status 1 if the configuration file cannot
        be loaded.
        """
        if backgrounded:
            logger.createLogger(True, "/tmp/nntsc.log", "NNTSC")
        self.config = config

        # Created up front so that get_processes() (used by the module-level
        # cleanup handler) is safe to call before configure() has run.
        self.processes = []
        self.db = None
        self.queueid = None

        # load_nntsc_config() signals failure by returning 0 (configure()
        # relies on the same convention); bail out rather than crash on
        # the attribute access below.
        nntsc_conf = load_nntsc_config(self.config)
        if nntsc_conf == 0:
            sys.exit(1)

        # Only modules explicitly set to "yes" in the [modules] section
        # are imported.
        enabled_modules = [module
                for module in nntsc_conf.options('modules')
                if nntsc_conf.get('modules', module) == "yes"]
        self.modules = import_parsers(enabled_modules)

        self.exporter = NNTSCExporter(listen_port)
        self.backgrounded = backgrounded
        self.exportonly = exportonly
        self.querytimeout = querytimeout

    def start_module(self, name, mod, conf):
        """Prepare (but do not start) the child process for one module.

        Fetches the module's existing streams from the database, builds a
        daemonic Process that will call mod.run_module(), registers the
        module's routing key with the exporter and appends the process to
        self.processes. If the stream query raises DBQueryException the
        error is logged and the module is skipped.

        Parameters:
          name -- module name; also used as the process name and as the
                  suffix of the routing key.
          mod -- module object providing run_module().
          conf -- configuration object forwarded to run_module().
        """
        try:
            # NOTE(review): 15 is presumably a retry/timeout value for
            # connect_db() -- confirm against DBInsert.
            self.db.connect_db(15)
            streams = self.db.select_streams_by_module(name)
        except DBQueryException as e:
            logger.log(e)
            return

        routingkey = self.queueid + "-" + name
        modargs = (streams, conf, routingkey, 'nntsclive', self.queueid)
        proc = Process(name=name, target=mod.run_module, args=modargs)
        proc.daemon = True

        self.exporter.register_source(routingkey, self.queueid)
        self.processes.append(proc)

    def configure(self):
        """Read the configuration and set up the database and exporter.

        Creates self.db, resets self.processes, determines self.queueid
        (None when running export-only) and configures the exporter. Any
        configuration failure exits the program with status 1.

        On success self.config is replaced by the parsed configuration
        object, so this method must only be called once per instance.
        """
        nntsc_conf = load_nntsc_config(self.config)
        if nntsc_conf == 0:
            sys.exit(1)

        dbconf = get_nntsc_db_config(nntsc_conf)
        if dbconf == {}:
            sys.exit(1)

        if dbconf["name"] == "":
            logger.log("No database name specified in the NNTSC configuration file")
            sys.exit(1)

        self.db = DBInsert(dbconf["name"], dbconf["user"], dbconf["pass"],
                dbconf["host"])

        logger.log("Connecting to NNTSC database %s on host %s with user %s" % (dbconf["name"], dbconf["host"], dbconf["user"]))

        influxconf = get_influx_config(nntsc_conf)
        if influxconf == {}:
            sys.exit(1)

        if influxconf["useinflux"]:
            logger.log("Connecting to Influx database {} on host {} with user {}".format(
                influxconf["name"], influxconf["host"], influxconf["user"]))

        self.processes = []

        if self.exportonly:
            # No collection modules will run, so there is no live queue.
            self.queueid = None
        else:
            queueid = get_nntsc_config(nntsc_conf, "liveexport", "queueid")
            if queueid == "NNTSCConfigError":
                sys.exit(1)
            if queueid == "NNTSCConfigMissing":
                # Fall back to the default queue name.
                queueid = "nntsclivequeue"
            self.queueid = queueid

        if self.exporter.configure(self.config, self.querytimeout,
                self.queueid) == -1:
            sys.exit(1)

        # From here on self.config is the parsed config object rather than
        # the file path; start_module() forwards it to each module.
        self.config = nntsc_conf

    def run(self):
        """Configure the collector, start the modules and run the exporter.

        Unless running export-only, every enabled module gets a child
        process which is started before the exporter's run() is entered.
        Once the exporter returns, the child processes are joined.
        """
        logger.log("Starting NNTSC Collector")
        self.configure()

        if not self.exportonly:
            for name, mod in self.modules.items():
                self.start_module(name, mod, self.config)
            for proc in self.processes:
                proc.start()

        self.exporter.run()

        for proc in self.processes:
            proc.join()

    def get_processes(self):
        """Return the list of module child processes (possibly empty)."""
        return self.processes

def cleanup():
    """Terminate every module child process and exit with status 0.

    Relies on the module-level ``dc`` DataCollector instance created by the
    script entry point, so it must only be called after ``dc`` exists.
    """
    logger.log("Calling cleanup function\n")
    for proc in dc.get_processes():
        proc.terminate()
    # sys.exit() rather than the site-provided exit() helper, which is meant
    # for interactive use and is not guaranteed to be defined.
    sys.exit(0)

if __name__ == '__main__':
    # Script entry point: parse the command line, then run the collector
    # either as a daemon (-b) or in the foreground.

    parser = argparse.ArgumentParser()
    parser.add_argument("-C", "--configfile", help="Specify the location of the NNTSC config file")
    parser.add_argument("-p", "--port", help="The port to listen for incoming connections on (default: 61234)", default=61234, type=int)
    parser.add_argument("-b", "--background", help="Run as a daemon", action="store_true")
    parser.add_argument("-P", "--pidfile", help="PID file location (if running backgrounded)", default=None)
    parser.add_argument("-E", "--exportonly", help="Run the exporter only -- do not collect any new data", action="store_true")
    parser.add_argument("-T", "--querytimeout", help="Cancel any database queries that exceed this number of seconds", default=0, type=int)

    args = parser.parse_args()

    if args.configfile is None:
        # sys.stderr.write() behaves the same on Python 2 and 3, unlike the
        # "print >> stream" statement form.
        sys.stderr.write("Must provide a config file using -C!\n")
        sys.exit(1)

    config = args.configfile
    listen_port = args.port

    # NOTE: dc is deliberately a module-level name; cleanup() looks it up
    # as a global.
    if args.background:
        if args.pidfile is None:
            pidfile = None
        else:
            pidfile = PidFile(args.pidfile)

        context = daemon.DaemonContext()
        context.pidfile = pidfile

        # The collector is created inside the daemon context so that it is
        # constructed only after the process has been daemonised.
        with context:
            dc = DataCollector(listen_port, config, True, args.exportonly,
                    args.querytimeout)
            dc.run()

    else:
        dc = DataCollector(listen_port, config, False, args.exportonly,
                args.querytimeout)
        try:
            dc.run()
        except KeyboardInterrupt:
            # Ctrl-C in the foreground: stop the module processes and exit.
            cleanup()

# vim: set sw=4 tabstop=4 softtabstop=4 expandtab :
