# Source code for mllaunchpad.cli

"""This module provides the command line interface for ML Launchpad"""

# Stdlib imports
import json
import sys
from logging import Logger
from typing import Dict, Optional

# Third-party imports
import click
from flask import Flask

# Project imports
import mllaunchpad as mllp
from mllaunchpad import logutil
from mllaunchpad.api import ModelApi, generate_raml

# Fix for click using the wrong name if run using `python -m mllaunchpad`:
# Click builds the program name shown in usage/help text from sys.argv[0],
# which under `python -m` is the executed module's file path rather than the
# package name. Overwrite it with the package name (mllp.__name__).
# NOTE(review): this mutates the process-global sys.argv at import time.
sys.argv[0] = mllp.__name__

# Adapted from the "Command Aliases" example in Click's documentation
# (Advanced Patterns).
class AliasedGroup(click.Group):
    """Commands can be abbreviated, e.g. t or tr for train, a for api, etc."""

    def get_command(self, ctx, cmd_name):
        """Resolve ``cmd_name`` exactly, or as an unambiguous prefix.

        :param ctx: The current Click context.
        :param cmd_name: Command name or prefix typed by the user.
        :return: The matching command, or ``None`` if no command name starts
            with ``cmd_name`` (leaving the "unknown command" case to Click).
        :raises click.UsageError: (via ``ctx.fail``) if ``cmd_name`` is a
            prefix of more than one command.
        """
        # An exact name always wins, even if it is also a prefix of others.
        rv = super().get_command(ctx, cmd_name)
        if rv is not None:
            return rv
        matches = [
            x for x in self.list_commands(ctx) if x.startswith(cmd_name)
        ]
        if not matches:
            return None
        if len(matches) == 1:
            # Tell the user (on stderr, so stdout output stays clean) which
            # command the abbreviation expanded to.
            click.echo(
                "Command {} matches {}".format(cmd_name, matches[0]), err=True
            )
            return super().get_command(ctx, matches[0])
        # Several commands share this prefix: abort with a usage error
        # listing the candidates instead of guessing.
        ctx.fail(
            "Command {} is ambiguous: {}".format(
                cmd_name, ", ".join(sorted(matches))
            )
        )
# TODO: Migrate to @dataclass when dropping support for Python 3.6
class Settings:
    """Shared state handed from the CLI group callback to its subcommands.

    The ``main`` group callback fills in ``verbose``, ``logger`` and
    ``conf_file``; the configuration itself is only loaded and validated on
    first access of :attr:`config`, and then cached.
    """

    def __init__(self):
        self.verbose: bool = False
        self.logger: Optional[Logger] = None
        # Path given via --config; when unset, get_validated_config() is
        # called without a path.
        self.conf_file: Optional[str] = None
        # Cached validated config; None means "not loaded yet".
        self._config: Optional[Dict] = None

    @property
    def config(self) -> Dict:
        """Return the validated configuration, loading it on first access.

        Uses ``conf_file`` if set, otherwise lets
        ``mllp.get_validated_config()`` pick its default location.
        """
        # getattr() keeps this working for instances created without
        # __init__. Compare with None (not truthiness) so a loaded-but-empty
        # config is cached instead of being re-read on every access.
        cached = getattr(self, "_config", None)
        if cached is None:
            if self.conf_file:
                cached = mllp.get_validated_config(self.conf_file)
            else:
                cached = mllp.get_validated_config()
            self._config = cached
        return cached
# Decorator that passes the shared Settings object as the first argument;
# ensure=True makes Click create it if it does not exist yet.
pass_settings = click.make_pass_decorator(Settings, ensure=True)


@click.group(
    cls=AliasedGroup, context_settings=dict(help_option_names=["-h", "--help"])
)
@click.version_option(prog_name="ML Launchpad")
@click.option("--verbose", "-v", is_flag=True, help="Print debug messages.")
@click.option(
    "--config",
    "-c",
    type=click.Path(exists=True),
    help="Use this configuration file. [default: look for env var LAUNCHPAD_CFG or ./LAUNCHPAD_CFG.yml]",
)
@click.option(
    "--log-config",
    "-l",
    type=click.Path(exists=True),
    help="Use this log configuration file. [default: look for env var LAUNCHPAD_LOG or ./LAUNCHPAD_LOG.yml]",
)
@pass_settings
def main(settings, log_config, config, verbose):
    """Train, test or run a config file's model."""
    # Initialize logging before any library code so that mllp can log stuff
    if log_config:
        settings.logger = logutil.init_logging(log_config, verbose=verbose)
    else:
        settings.logger = logutil.init_logging(verbose=verbose)
    settings.verbose = verbose
    # Only the path is stored here; Settings.config loads it lazily.
    settings.conf_file = config


@main.command()
@pass_settings
def train(settings):
    """Run training, store created model and metrics."""
    # train_model returns a (model, metrics) pair; only metrics are shown.
    _, metrics = mllp.train_model(settings.config)
    print(metrics)


@main.command()
@pass_settings
def retest(settings):
    """Retest existing model, update metrics."""
    metrics = mllp.retest(settings.config)
    print(metrics)


@main.command()
@pass_settings
def api(settings):
    """Run API server in unsafe debug mode."""
    # The gunicorn example needs an explicit bind address after -b; the app
    # module path is a separate positional argument.
    settings.logger.warning(
        "Starting Flask debug server. In production, please "
        "use a WSGI server, e.g.\n"
        "'export LAUNCHPAD_CFG=addition_cfg.yml'\n"
        "'gunicorn -w 4 -b 127.0.0.1:5000 mllaunchpad.wsgi:application'"
    )
    app = Flask(__name__, root_path=settings.config["api"].get("root_path"))
    ModelApi(settings.config, app, debug=True)
    # Flask apps must not be run in debug mode in production, because this allows for arbitrary code execution.
    # We know that and advise the user that this is only for debugging, so this is not a security issue (marked nosec):
    app.run(debug=True)  # nosec


@main.command()
@click.argument("json-file", type=click.File("r"), default=sys.stdin)
@pass_settings
def predict(settings, json_file):
    """Run prediction on features from JSON file ( - for stdin).

    \b
    Example JSON:
    {
        "petal.width": 1.4,
        "petal.length": 2.0,
        "sepal.width": 1.8,
        "sepal.length": 4.0
    }
    """
    # The "\b" line above tells Click not to re-wrap the JSON example in
    # the help output.
    arg_dict = json.load(json_file)
    output = mllp.predict(
        settings.config, arg_dict=arg_dict, use_live_code=True
    )
    print(output)


@main.command(name="generate-raml")
@click.argument("datasource-name", type=str, required=True)
@pass_settings
def cli_generate_raml(settings, datasource_name):
    """Generate and print RAML template from DATASOURCE_NAME.

    The datasource named DATASOURCE_NAME in the config will be used to create
    the API's query parameters (from columns), types, and examples.
    """
    print(generate_raml(settings.config, data_source_name=datasource_name))


if __name__ == "__main__":
    # PyLint (used by Codacity) does not know that the signature of `main()`
    # is being changed by Click at runtime. Disable the warning.
    sys.exit(
        main()  # pylint: disable=no-value-for-parameter; # pragma: no cover
    )