import logging
from logging.handlers import SysLogHandler

# Shared record format for every handler attached to the root logger; the
# DEBUG console handler configured further down reuses it.
LOGFMT = logging.Formatter(
    "%(asctime)s [ltsp-cluster-agent] %(levelname)s [%(name)s] %(message)s")

# Always log to syslog.
# NOTE(review): SysLogHandler() without arguments uses its default address
# (UDP to localhost:514); if the local syslog daemon only listens on /dev/log
# these records are dropped -- confirm for the target deployments.
syslog = SysLogHandler()
syslog.setFormatter(LOGFMT)
logging.getLogger().addHandler(syslog)

import SocketServer
import BaseHTTPServer
import SimpleHTTPServer
import SimpleXMLRPCServer

# Choose the SSL implementation: the standard-library ssl module when it is
# importable, otherwise pyOpenSSL.  ssl_version records the choice
# (1 -> stdlib ssl, 0 -> pyOpenSSL) and is consulted wherever sockets are
# wrapped or shut down.
ssl_version = 0
try:
    import ssl
    ssl_version = 1
except ImportError:
    # Only a missing module should trigger the fallback; any other error
    # raised while importing ssl must stay visible.
    from OpenSSL import SSL

import socket, os, sys, hashlib
from base64 import b64decode
from plugin import Plugin
from configobj import ConfigObj

def get_config_path(conf, *path):
    """Walk a nested configuration object and return the value at *path*.

    Each element of ``path`` is applied as a subscript to the result of the
    previous lookup, starting from ``conf``.  With an empty ``path`` the
    ``conf`` object itself is returned.

    Returns None when a step is missing (KeyError for mappings, IndexError
    for list values) or when an intermediate value cannot be subscripted by
    the next element (TypeError, e.g. a scalar where a section was expected).
    """
    res = conf
    try:
        for key in path:
            res = res[key]
    except (KeyError, IndexError, TypeError):
        return None
    return res

class AuthorizationFailure(Exception):
    """Raised when an XML-RPC call is rejected by the credential check."""

class SecureXMLRPCServer(BaseHTTPServer.HTTPServer,SimpleXMLRPCServer.SimpleXMLRPCDispatcher):
    """HTTP server dispatching XML-RPC calls over an SSL-wrapped socket.

    The listening socket is wrapped with the standard-library ``ssl`` module
    or with pyOpenSSL, depending on the module-level ``ssl_version`` flag,
    using the ``[server] sslkey`` / ``[server] sslcert`` entries of the
    module-level ``config``.
    """

    def __init__(self, server_address, SecureXMLRPCRequestHandler, logRequests=False):
        """Pick the address family, wrap the socket with SSL, bind and listen.

        server_address -- (host, port) tuple to bind to.
        SecureXMLRPCRequestHandler -- handler class instantiated per request.
        logRequests -- stored on the instance (default False).

        If wrapping, binding or listening fails, the sockets created here are
        closed before the exception is re-raised, so a failed start does not
        leak a file descriptor.
        """
        # Deal with IPv6: switch to AF_INET6 if the bind address resolves to
        # at least one IPv6 address.
        addrinfo = socket.getaddrinfo(server_address[0], server_address[1])
        for entry in addrinfo:
            if entry[0] == socket.AF_INET6:
                self.address_family = socket.AF_INET6

        self.logRequests = logRequests
        # allow_none=False, encoding=None
        SimpleXMLRPCServer.SimpleXMLRPCDispatcher.__init__(self, False, None)
        # BaseServer.__init__ (not TCPServer.__init__) because the listening
        # socket is created and SSL-wrapped explicitly below.
        SocketServer.BaseServer.__init__(self, server_address, SecureXMLRPCRequestHandler)

        keyfile = get_config_path(config, "server", "sslkey")
        certfile = get_config_path(config, "server", "sslcert")
        raw_sock = socket.socket(self.address_family, self.socket_type)
        try:
            if ssl_version == 1:
                self.socket = ssl.wrap_socket(raw_sock, keyfile, certfile)
            else:
                ctx = SSL.Context(SSL.SSLv23_METHOD)
                ctx.use_privatekey_file(keyfile)
                ctx.use_certificate_file(certfile)
                self.socket = SSL.Connection(ctx, raw_sock)

            self.server_bind()
            self.server_activate()
        except Exception:
            # Close both the SSL wrapper (if it was created) and the raw
            # socket it was built on, then propagate the original error.
            wrapped = getattr(self, "socket", None)
            if wrapped is not None:
                wrapped.close()
            raw_sock.close()
            raise

class SecureXMLRPCRequestHandler(SimpleXMLRPCServer.SimpleXMLRPCRequestHandler):
    """XML-RPC request handler with per-plugin HTTP Basic authentication.

    The module part of the method name (``plugin.function`` or
    ``plugin.sub.function``) selects the plugin whose credentials are
    checked by ``authenticate`` before the call is dispatched.
    """

    def handle_one_request(self):
        """Handle one request, logging and then ignoring any error.

        Errors raised while reading or answering a request (for instance a
        client dropping the connection) are logged at debug level instead of
        being propagated to the server.
        """
        try:
            SimpleXMLRPCServer.SimpleXMLRPCRequestHandler.handle_one_request(self)
        except Exception:
            logging.debug("Error while handling request from %s",
                          self.client_address, exc_info=True)

    def setup(self):
        """Create buffered read/write file objects directly on the connection.

        NOTE(review): presumably overridden because pyOpenSSL connections do
        not provide makefile(), which StreamRequestHandler.setup relies on.
        """
        self.connection = self.request
        self.rfile = socket._fileobject(self.request, "rb", self.rbufsize)
        self.wfile = socket._fileobject(self.request, "wb", self.wbufsize)

    def do_POST(self):
        """Read the XML-RPC request body, dispatch it and send the response.

        Success: HTTP 200 with the text/xml payload.  AuthorizationFailure:
        HTTP 401.  Any other error: logged, HTTP 500 with an empty body.

        NOTE(review): the stdlib _marshaled_dispatch appears to convert
        exceptions raised by _dispatch into XML-RPC faults (returned with
        HTTP 200), which would make the 401 branch reachable only for errors
        raised outside the dispatch -- confirm before relying on it.
        """
        try:
            data = self.rfile.read(int(self.headers["content-length"]))
            response = self.server._marshaled_dispatch(
                    data, getattr(self, '_dispatch', None)
                )
        except AuthorizationFailure:
            # send_error() already ends the header block and writes the error
            # body; another end_headers() would append a stray CRLF.
            self.send_error(401, 'Authentication failed')
        except Exception:
            logging.exception("Unable to process XML-RPC request from %s",
                              self.client_address)
            self.send_response(500)
            self.send_header("Content-length", "0")
            self.end_headers()
        else:
            self.send_response(200)
            self.send_header("Content-type", "text/xml")
            self.send_header("Content-length", str(len(response)))
            self.end_headers()
            self.wfile.write(response)
            self.wfile.flush()
            self._shutdown_connection()

    def _shutdown_connection(self):
        """Shut the connection down after the response has been flushed."""
        if ssl_version == 1:
            # The stdlib SSLSocket.shutdown() requires a `how` argument, so
            # the bare call only ever raised TypeError here.  Leave teardown
            # to the server, which closes the request after the handler.
            return
        # pyOpenSSL: send the SSL shutdown alert.
        self.connection.shutdown()

    def _dispatch(self, method, params):
        """Authenticate the caller for the target plugin, then dispatch.

        method -- dotted method name, ``module[.submodule].function``.
        params -- call arguments as decoded from the request.

        Credentials come from an HTTP Basic ``Authorization`` header.  The
        password is SHA-1 hex-hashed before being checked by authenticate().
        Raises AuthorizationFailure if the header is missing, uses another
        scheme, or is malformed; if the plugin cannot be checked; or if the
        credentials do not match.
        """
        logging.debug("Remote call on %s with params %s", method, str(params))
        # Methods will be [module[.submodule].]function, remove the function part
        plugin_path = method.rpartition(".")[0]
        # Only the module (not the submodule) is used for authentication
        plugin = plugin_path.partition(".")[0]

        try:
            header = self.headers.get('Authorization')
            if header is None:
                raise AuthorizationFailure("Missing Authorization header")
            (scheme, _, encoded) = header.partition(' ')
            if scheme != 'Basic':
                raise AuthorizationFailure("Unsupported Authorization scheme")
            (username, _, password) = b64decode(encoded).partition(':')
            password = hashlib.sha1(password).hexdigest()
            if not authenticate(plugin, username, password):
                raise AuthorizationFailure("Wrong credentials")
        except AuthorizationFailure:
            raise
        except Exception:
            # Malformed header, unknown or unloaded plugin, ...: fail closed.
            raise AuthorizationFailure(str(sys.exc_info()[1]))

        # Export the client address to the plugin, only once the caller is
        # known to be authorised; skip plugins that failed to load (None).
        target = plugins.get(plugin_path)
        if target is not None:
            target.client_address = self.client_address

        return self.server._dispatch(method, params)

# Find and load the configuration file.  The first existing candidate wins:
# ./config/agent.conf, then /etc/ltsp/agent.conf, then the Windows-style
# location under sys.prefix.  Without any of them the agent cannot start.
for _config_candidate in (os.getcwd() + '/config/agent.conf',
                          '/etc/ltsp/agent.conf',
                          '%s\\ltsp\\config\\agent.conf' % sys.prefix):
    if os.path.exists(_config_candidate):
        config = ConfigObj(_config_candidate)
        break
else:
    logging.error("Unable to find configuration file")
    sys.exit(1)

# Find the plugin directory and put it first on the import path.  The first
# existing candidate wins: ./plugins, the [server] plugindir setting (only
# when set), the system-wide location, then the Windows-style location under
# sys.prefix.
plugin_dirs = []
_configured_plugindir = get_config_path(config, "server", "plugindir")
_plugin_candidates = [os.getcwd() + '/plugins']
if _configured_plugindir:
    _plugin_candidates.append(_configured_plugindir)
_plugin_candidates.append('/usr/share/ltsp-agent/plugins/')
_plugin_candidates.append('%s\\ltsp\\plugins' % sys.prefix)

for _plugin_candidate in _plugin_candidates:
    if os.path.exists(_plugin_candidate):
        sys.path.insert(0, _plugin_candidate)
        plugin_dirs.append(_plugin_candidate)
        break
else:
    logging.error("Unable to find plugin directory, starting without plugins")

# Look for plugins: each (non-hidden) directory inside a plugin directory is
# a candidate, keyed by its name; the instance is filled in later by server().
# Plain files (README, *.py, *.pyc, ...) and hidden entries (.svn, ...) cannot
# be loaded as "<name>.<name>" and would only yield load errors and None
# entries in the registry.
plugins={}
for plugin_dir in plugin_dirs:
    for entry in os.listdir(plugin_dir):
        if entry.startswith('.'):
            continue
        if not os.path.isdir(os.path.join(plugin_dir, entry)):
            continue
        plugins[entry]=None

def authenticate(plugin, username, password):
    """Check *username*/*password* against the configured credentials.

    plugin -- top-level plugin name, used as a key into ``plugins``.
    username -- user name from the request.
    password -- value compared with the configured ``password`` entry (the
                request handler passes a SHA-1 hex digest).

    Credentials are looked up under ``[auth] [[username]] password`` in the
    plugin's own config first.  If the plugin lacks either the user entry or
    its password, the server-wide ``config`` is used instead.

    Returns True only when a user entry exists and the password matches.
    Raises KeyError if *plugin* is not in ``plugins`` and AttributeError if
    its entry is None or has no ``config``; the caller treats both as a
    failed authentication.
    """
    plugin_config = plugins[plugin].config
    user_entry = get_config_path(plugin_config, "auth", username)
    expected = get_config_path(plugin_config, "auth", username, "password")

    # Fall back to the server-wide credentials when the plugin does not
    # define a complete entry for this user.
    if user_entry is None or expected is None:
        user_entry = get_config_path(config, "auth", username)
        expected = get_config_path(config, "auth", username, "password")

    res = user_entry is not None and password == expected
    logging.debug("Auth %s by '%s' on plugin '%s'",
                  "success" if res else "failure", username, plugin)
    return res


def server():
    """Create the HTTPS XML-RPC server, load the plugins and serve forever.

    The bind address and port come from ``[server] bindaddr`` / ``bindport``.
    For each discovered plugin name ``p`` the class ``p.p`` is imported and
    instantiated.  The instance receives the server configuration, the
    ``get_config_path`` helper and the plugin registry, then ``init_plugin()``
    is called.  Each name returned by ``rpc_functions()`` is registered as
    ``p.<function>``, and ``start_threads()`` is called.  Plugins that fail to
    load are logged and skipped.

    When serving stops (normally or via an exception such as
    KeyboardInterrupt), ``stop_threads()`` is called on every instantiated
    plugin.  A failure in one plugin is logged; it neither prevents the others
    from being stopped nor replaces the exception that ended serving.
    """
    server_address = (get_config_path(config,"server","bindaddr"),int(get_config_path(config,"server","bindport")))
    rpc_server = SecureXMLRPCServer(server_address, SecureXMLRPCRequestHandler)

    # Get all plugins and for each xmlrpc function, announce it
    for plugin in plugins:
        try:
            plugins[plugin]=getattr(__import__(plugin, {}, {}, [plugin]),plugin)()

            # Export a few functions and variables to the plugin
            plugins[plugin].serverconfig=config
            plugins[plugin].get_config_path=get_config_path
            plugins[plugin].plugins=plugins
            plugins[plugin].init_plugin()
        except AttributeError:
            logging.error("Invalid plugin %s (%s)", plugin, str(sys.exc_info()[1]))
        except Exception:
            logging.error("Unable to load plugin %s (%s)", plugin, str(sys.exc_info()[1])[:100])
        else:
            for function in plugins[plugin].rpc_functions():
                rpc_server.register_function(
                    getattr(plugins[plugin], function),
                    plugin + '.' + function
                )
            plugins[plugin].start_threads()

    sa = rpc_server.socket.getsockname()
    logging.info("Serving HTTPS on %s port %s", sa[0], sa[1])
    try:
        rpc_server.serve_forever()
    finally:
        for name, instance in plugins.items():
            # Plugins whose import failed are still None in the registry.
            if instance is None:
                continue
            try:
                instance.stop_threads()
            except Exception:
                logging.exception("Unable to stop plugin %s", name)

import os
def parse_log_level(value):
    """Return the numeric logging level described by the string *value*.

    *value* may be a registered level name, case-insensitive (e.g. "debug",
    "Warning"), or an integer literal (e.g. "15").  Anything else yields
    logging.NOTSET.
    """
    value = value.upper()
    # getLevelName() maps a registered name to its number; for unknown
    # names it returns a "Level %s" string instead.
    named = logging.getLevelName(value)
    if isinstance(named, int):
        return named
    try:
        return int(value)
    except ValueError:
        return logging.NOTSET

# Optional console logging: when DEBUG is set in the environment, also log to
# a stream handler and set the root level from DEBUG's value.
if os.environ.get("DEBUG", None) is not None:
    level = parse_log_level(os.environ["DEBUG"])
    logging.root.setLevel(level)
    sysout = logging.StreamHandler()
    sysout.setFormatter(LOGFMT)
    logging.root.addHandler(sysout)

def start_server():
    """Run server() until it is interrupted from the keyboard.

    On KeyboardInterrupt, log "Done" and terminate immediately with exit
    status 1 via os._exit(), which skips normal interpreter shutdown.
    """
    interrupted = False
    try:
        server()
    except KeyboardInterrupt:
        interrupted = True

    if interrupted:
        logging.info("Done")
        # NOTE(review): presumably os._exit() is used so that plugin threads
        # still running cannot keep the process alive -- confirm.
        os._exit(1)
