Source code for mllaunchpad.api

"""This module contains functionality for generic creation
   and handling of RESTful APIs for Machine Learning Models.
   Among other things, it parses the RAML definition
   and validates request parameters.
"""

# Stdlib imports
import logging
import re
from collections import OrderedDict

# Third-party imports
import ramlfications
from flask_restful import Api, Resource, reqparse
from werkzeug.datastructures import FileStorage

# Project imports
from mllaunchpad import model_actions, resource


# Module-level logger, named after this module (``__name__``).
logger = logging.getLogger(__name__)

# TODO: factor out datasource/model preparation code -- use model_actions instead of low-level functionality


def get_api_base_url(config):
    """Build the root URL path under which the model's REST API is served.

    Params:
        config: configuration dictionary; ``config["api"]["name"]`` and
                ``config["model"]["version"]`` are read.

    Returns:
        A string of the form ``"/<api name>/v<major model version>"``.
    """
    return "/{}/{}".format(
        config["api"]["name"], _get_major_api_version(config)
    )
def _get_major_api_version(config):
    """Derive the API version string (``"v<major>"``) from the model version.

    Params:
        config: configuration dictionary; ``config["model"]["version"]`` is read.

    Returns:
        ``"v"`` followed by the leading run of digits of the model version,
        e.g. ``"v1"`` for ``"1.2.3"``.

    Raises:
        ValueError: if the model version does not start with a digit.
    """
    version = config["model"]["version"]
    # A YAML config like ``version: 1.2`` yields a float, not a string,
    # so coerce before matching (re.match would raise TypeError otherwise).
    match = re.match(r"\d+", str(version))
    if match is None:
        raise ValueError(
            "Model version in configuration is malformed. Expected x.y.z, got {}".format(
                version
            )
        )
    return "v{}".format(match.group(0))


def _load_raml(config):
    """Parse the RAML file referenced by the configuration and check its version.

    Params:
        config: configuration dictionary; ``config["api"]["raml"]`` is the
                path of the RAML file.

    Returns:
        The parsed RAML root object from ``ramlfications.parse``.

    Raises:
        ValueError: if the RAML ``version`` differs from the major API version
                    derived from the model version in the configuration.
    """
    raml_file = config["api"]["raml"]
    logger.debug("Reading RAML file %s", raml_file)
    raml = ramlfications.parse(raml_file)
    conf_version = _get_major_api_version(config)
    if raml.version != conf_version:
        raise ValueError(
            "API version in RAML {} does not match API version in config {}".format(
                raml.version, conf_version
            )
        )
    return raml


# Maps RAML parameter type names to the converter callables used by reqparse.
_type_lookup = {
    "number": float,
    "integer": int,
    "string": str,
    "boolean": bool,
    "array": list,
    # https://medium.com/raml-api/array-enumeration-in-raml-b69950c75bb3
    # RAML Types: any, object, array, union via type expression,
    # one of the following scalar types: number, boolean, string,
    # date-only, time-only, datetime-only, datetime, file, integer, or nil
}


def _create_request_parser(resource_obj):
    """Create a request parser for the parameters of a RAML resource.

    We only use query_params and form_params for now (no custom headers,
    body, etc.). Note that they are used interchangeably in code (so e.g.
    even if RAML requires only form_params, in reality we also accept the
    same params as query_params). If the resource's first body has mime type
    ``multipart/form-data``, a ``file`` argument read from the uploaded files
    is added as well.

    Params:
        resource_obj: RAML resource object as parsed by ``ramlfications``

    Returns:
        A ``reqparse.RequestParser`` created with ``bundle_errors=True``.

    Raises:
        ValueError: if several non-repeatable parameters share a name, if a
                    parameter's type is not in ``_type_lookup``, or if the last
                    URI parameter has the same name as a query/form parameter.
    """
    all_params = (resource_obj.query_params or []) + (
        resource_obj.form_params or []
    )
    added_arguments = set()
    parser = reqparse.RequestParser(bundle_errors=True)
    for p in all_params:
        if p.name in added_arguments and not p.repeat:
            raise ValueError(
                "Cannot handle RAML with multiple parameters sharing same name {}".format(
                    p.name
                )
            )
        # Array notation "<type>[]":
        # https://help.mulesoft.com/s/article/Repeat-query-parameters-using-RAML-1-0
        try:
            param_type, _ = p.type.split("[]")
            is_array = True
        except ValueError:
            param_type = p.type
            is_array = False
        is_array = is_array or p.repeat
        try:
            converter = _type_lookup[param_type]
        except KeyError:
            raise ValueError(
                'Unsupported type "{}" of parameter "{}" in RAML. Supported types: {}'.format(
                    p.type, p.name, ", ".join(sorted(_type_lookup))
                )
            ) from None
        # Avoid a literal "None - Error: ..." help text for undocumented params.
        help_prefix = (
            "" if p.description is None else str(p.description) + " - "
        )
        parser.add_argument(
            p.name,
            type=converter,
            required=p.required,
            default=p.default,
            action="append" if is_array else "store",
            choices=p.enum,
            help=help_prefix + "Error: {error_msg}",
        )
        added_arguments.add(p.name)
    if (
        resource_obj.uri_params
        and resource_obj.uri_params[-1].name in added_arguments
    ):
        raise ValueError(
            'Resource URI parameter in RAML "{}" must not have same name as a parameter'.format(
                resource_obj.uri_params[-1].name
            )
        )
    if (
        resource_obj.body
        and resource_obj.body[0].mime_type == "multipart/form-data"
    ):
        # todo: how can we make sure the file mime type is correct?
        parser.add_argument("file", type=FileStorage, location="files")
    return parser


def _get_resources(raml):
    """Pick the usable resources out of a parsed RAML definition.

    Only resources whose method is "get" or "post" are considered. They are
    sorted into at most one plain resource (no URI parameters, no body), at
    most one resource with URI parameters (resource id in the URL), and at
    most one file-upload resource (first body of type multipart/form-data).

    Returns:
        Tuple ``(res_normal, res_with_id, res_file)``; entries are None if absent.

    Raises:
        ValueError: if there are no usable resources, more than three, or more
                    than one of any kind.
    """
    usable_methods = ["get", "post"]
    usable_rs = [r for r in raml.resources if r.method in usable_methods]
    rs_without_resource_id = [
        r for r in usable_rs if not r.uri_params and not r.body
    ]
    rs_with_resource_id = [r for r in usable_rs if r.uri_params]
    rs_file_upload = [
        r
        for r in usable_rs
        if r.body and r.body[0].mime_type == "multipart/form-data"
    ]
    if (
        not usable_rs
        or len(usable_rs) > 3
        or len(rs_without_resource_id) > 1
        or len(rs_with_resource_id) > 1
        or len(rs_file_upload) > 1
    ):
        raise ValueError(
            (
                "RAML must contain one to three resources with a method of '{}'. "
                "At most one resource each with and without uri parameter (resource id) "
                "or one file upload resource.\n"
                "There are {} resources with matching methods. Resources in RAML: {}"
            ).format(usable_methods, len(usable_rs), raml.resources)
        )
    res_normal = rs_without_resource_id[0] if rs_without_resource_id else None
    res_with_id = rs_with_resource_id[0] if rs_with_resource_id else None
    res_file = rs_file_upload[0] if rs_file_upload else None
    return res_normal, res_with_id, res_file
class QueryResource(Resource):
    """Flask-RESTful resource that answers GET requests with model predictions.

    Adapted from https://flask-restful.readthedocs.io/en/latest/quickstart.html
    """

    def __init__(self, model_api_obj, parser):
        # model_api_obj: the ModelApi doing the prediction;
        # parser: reqparse parser for this resource's parameters.
        self.model_api = model_api_obj
        self.parser = parser

    def get(self):
        """Parse the request arguments (strict mode) and return the prediction."""
        # query_params and form_params are treated as interchangeable
        request_args = self.parser.parse_args(strict=True)
        logger.debug("Received GET request with arguments: %s", request_args)
        prediction = self.model_api.predict_using_model(request_args)
        return prediction
class GetByIdResource(Resource):
    """Flask-RESTful resource answering GET requests whose URL ends with an id."""

    def __init__(self, model_api_obj, parser, id_name):
        # id_name: name under which the URL id is handed to the model
        # (the RAML uri parameter name).
        self.model_api, self.parser, self.id_name = (
            model_api_obj,
            parser,
            id_name,
        )

    def get(self, some_resource_id):
        """Parse the request arguments, add the URL id, and return the prediction."""
        # query_params and form_params are treated as interchangeable
        parsed = self.parser.parse_args(strict=True)
        parsed[self.id_name] = some_resource_id
        logger.debug(
            "Received GET request for %s %s with arguments: %s",
            self.id_name,
            some_resource_id,
            parsed,
        )
        result = self.model_api.predict_using_model(parsed)
        return result
class QueryOrFileUploadResource(Resource):
    """Resource answering GET with query parameters and/or POST with a file upload.

    Either parser may be None: the resource can be registered for a
    query-only URL, a file-upload-only URL, or a URL serving both.
    """

    def __init__(self, model_api_obj, query_parser=None, file_parser=None):
        self.model_api = model_api_obj
        self.query_parser = query_parser
        self.file_parser = file_parser

    def get(self):
        """Answer a GET request using the query parser.

        Returns a 405 response if this resource was registered for file
        upload only (no query parser), instead of failing with a 500 error.
        """
        if self.query_parser is None:
            return (
                {"message": "The method is not allowed for the requested URL."},
                405,
                {"Allow": "POST"},
            )
        # treat query_params and form_params as interchangeable
        args = self.query_parser.parse_args(strict=True)
        logger.debug("Received GET request with arguments: %s", args)
        return self.model_api.predict_using_model(args)

    def post(self):
        """Answer a POST request using the file parser if present, else the query parser."""
        if self.file_parser:
            args = self.file_parser.parse_args(strict=True)
            # The "file" argument is not required by the parser, so it is
            # None when the request carries no upload.
            file_storage_obj = args.get("file")
            if file_storage_obj is None:
                logger.debug(
                    "Received POST request without file, arguments: %s", args
                )
            else:
                logger.debug(
                    "Received POST request with file %s of mimetype %s",
                    file_storage_obj.filename,
                    file_storage_obj.mimetype,
                )
        else:
            # treat query_params and form_params as interchangeable
            args = self.query_parser.parse_args(strict=True)
            logger.debug("Received POST request with arguments: %s", args)
        return self.model_api.predict_using_model(args)
class ModelApi:
    """Class to plug a Data-Scientist-created model into.

    This class handles the heavy lifting of APIs for the model.

    The model is a delegate which inherits from (=implements) ModelInterface.
    It needs to provide a predict function.

    For details, see the documentation in the module `model_interface`
    """

    def __init__(self, config, application, debug=False):
        """When initializing ModelApi, your model will be automatically
        retrieved from the model store based on the currently active
        configuration.

        Params:
            config:       configuration dictionary to use
            application:  flask application to use
            debug:        use current prediction code instead of that of
                          persisted model
        """
        self.model_config = config["model"]

        model_store = resource.ModelStore(config)
        self.model_wrapper = self._load_model(model_store, self.model_config)

        if debug:
            # TODO: Hacky, should use model_actions functionality for a lot of API functionality instead of duplicating.
            # Create a fresh model object from current code and transplant existing contents
            m_cls = model_actions._get_model_class(config, cache=True)
            self.model_wrapper = m_cls(contents=self.model_wrapper.contents)

        # Stored under a plain attribute name: a "__"-prefixed attribute set
        # inside this class is name-mangled, so a later
        # hasattr(..., "__graph") lookup would never find it.
        self._tf_graph = self._capture_tf_graph()

        self.datasources, self.datasinks = self._init_datasources(config)

        logger.debug("Initializing RESTful API")
        api = Api(application)
        api_url = get_api_base_url(config)
        raml = _load_raml(config)
        res_normal, res_with_id, res_file = _get_resources(raml)
        if res_file or res_normal:
            self._add_query_and_file_resources(
                api, api_url, res_normal, res_file
            )
        if res_with_id:
            self._add_id_resource(api, api_url, res_with_id)

    @staticmethod
    def _capture_tf_graph():
        """Return tensorflow's default graph, or None if it is unavailable.

        Workaround (tensorflow has problem with spontaneously created threads such as with Flask):
        https://kobkrit.com/tensor-something-is-not-an-element-of-this-graph-error-in-keras-on-flask-web-server-4173a8fe15e1
        """
        try:
            # Third-party imports
            import tensorflow as tf

            graph = tf.get_default_graph()
        except Exception as e:
            # Optional workaround: any failure (no tensorflow, no
            # get_default_graph in this version, ...) just disables it.
            logger.debug(
                'Optional tensorflow/flask workaround for "<tensor> is not an element of this graph" problem '
                + "resulted in: %s",
                e,
            )
            return None
        logger.info(
            'Stored tensorflow model\'s graph - tensorflow/flask workaround for "<tensor> is not an element of this graph" problem'
        )
        return graph

    def _add_query_and_file_resources(self, api, api_url, res_normal, res_file):
        """Register QueryOrFileUploadResource for the query and/or file-upload RAML resources."""
        resource_urls = {"query": None, "file": None}
        parsers = {"query": None, "file": None}
        if res_normal:
            logger.debug(
                "Adding query resource %s to api %s", res_normal.path, api_url
            )
            resource_urls["query"] = api_url + res_normal.path
            parsers["query"] = _create_request_parser(res_normal)
        if res_file:
            logger.debug(
                "Adding file-based resource %s to api %s",
                res_file.path,
                api_url,
            )
            resource_urls["file"] = api_url + res_file.path
            parsers["file"] = _create_request_parser(res_file)

        if resource_urls["query"] == resource_urls["file"]:
            # Same URL: one resource serves GET (query) and POST (file).
            api.add_resource(
                QueryOrFileUploadResource,
                resource_urls["query"],
                resource_class_kwargs={
                    "model_api_obj": self,
                    "query_parser": parsers["query"],
                    "file_parser": parsers["file"],
                },
            )
            return

        registered = [(kind, url) for kind, url in resource_urls.items() if url]
        for kind, res_url in registered:
            extra = {}
            if len(registered) > 1:
                # Flask-RESTful names the endpoint after the resource class by
                # default; registering the same class under a second URL with
                # the same endpoint name makes Flask refuse the new view.
                extra["endpoint"] = "{}_{}".format(
                    QueryOrFileUploadResource.__name__.lower(), kind
                )
            api.add_resource(
                QueryOrFileUploadResource,
                res_url,
                resource_class_kwargs={
                    "model_api_obj": self,
                    "query_parser": parsers["query"] if kind == "query" else None,
                    "file_parser": parsers["file"] if kind == "file" else None,
                },
                **extra
            )

    def _add_id_resource(self, api, api_url, res_with_id):
        """Register GetByIdResource under the parent path plus an id segment."""
        logger.debug(
            "Adding url-id resource %s to api %s", res_with_id.path, api_url
        )
        # The model receives the id under the name of the last URI parameter.
        uri_param_name = res_with_id.uri_params[-1].name
        resource_url = (
            api_url + res_with_id.parent.path + "/<string:some_resource_id>"
        )
        parser = _create_request_parser(res_with_id)
        api.add_resource(
            GetByIdResource,
            resource_url,
            resource_class_kwargs={
                "model_api_obj": self,
                "parser": parser,
                "id_name": uri_param_name,
            },
        )

    def predict_using_model(self, args_dict):
        """Call the model's predict function with the parsed request arguments.

        Params:
            args_dict: mapping of argument names to values from a request parser

        Returns:
            The prediction converted by ``resource.to_plain_python_obj``.
        """
        logger.debug("Prediction input %s", dict(args_dict))
        logger.info("Starting prediction")
        # Arguments are passed to the model sorted by name.
        args_ordered_dict = OrderedDict(sorted(args_dict.items()))
        inner_model = self.model_wrapper.contents
        predict_args = [
            self.model_config,
            self.datasources,
            self.datasinks,
            inner_model,
            args_ordered_dict,
        ]
        graph = getattr(self, "_tf_graph", None)
        if graph is not None:
            with graph.as_default():
                logger.info("Restored tensorflow model's graph")
                raw_output = self.model_wrapper.predict(*predict_args)
        else:
            raw_output = self.model_wrapper.predict(*predict_args)

        if (
            self.model_wrapper.have_columns_been_ordered
            and not resource._order_columns_called
        ):
            logger.warning(
                "Model has been trained on ordered columns, but "
                "prediction does not call function order_columns."
            )

        output = resource.to_plain_python_obj(raw_output)
        logger.debug("Prediction output %s", output)
        return output

    @staticmethod
    def _init_datasources(config):
        """Create the datasources and datasinks tagged "predict".

        If ``config["api"]["preload_datasources"]`` is truthy, every
        datasource's dataframe is fetched once up front.

        Returns:
            Tuple ``(datasources, datasinks)`` of dicts keyed by name.
        """
        logger.info("Initializing datasources...")
        dso, dsi = resource.create_data_sources_and_sinks(
            config, tags="predict"
        )
        logger.info(
            "%s datasource(s) initialized: %s", len(dso), list(dso.keys())
        )
        logger.info(
            "%s datasink(s) initialized: %s", len(dsi), list(dsi.keys())
        )
        if config["api"].get("preload_datasources"):
            logger.info("Preloading datasources...")
            for ds in dso.values():
                _ = ds.get_dataframe()
        return dso, dsi

    @staticmethod
    def _load_model(model_store, model_config):
        """Load the trained model described by ``model_config`` from the store."""
        logger.info("Loading model...")
        model, meta = model_store.load_trained_model(model_config)
        logger.info(
            "Model loaded: %s, version: %s, created %s",
            meta["name"],
            meta["version"],
            meta["created"],
        )
        return model
# Maps the pandas dtype name of a column (``str(series.dtype)``) to the RAML
# type used for it in generated RAML templates.
_pd_type_lookup = {
    "object": "string",
    "string": "string",  # pandas StringDtype
    "int64": "integer",
    "int32": "integer",
    "int16": "integer",
    "int8": "integer",
    "uint64": "integer",
    "uint32": "integer",
    "uint16": "integer",
    "uint8": "integer",
    "Int64": "integer",  # pandas nullable integer
    "Int32": "integer",
    "float64": "number",
    "float32": "number",
    "Float64": "number",  # pandas nullable float
    "bool": "boolean",
    "boolean": "boolean",  # pandas nullable BooleanDtype
    # str() of a datetime dtype includes its unit (e.g. "datetime64[ns]"); the
    # bare "datetime64" key is kept for backward compatibility.
    "datetime64": "date",  # https://github.com/raml-org/raml-spec/blob/master/versions/raml-08/raml-08.md#date-representations
    "datetime64[ns]": "date",
    "datetime64[us]": "date",
    "datetime64[ms]": "date",
    "datetime64[s]": "date",
    "category": "string",  # plus enum: list(series.cat.categories)
    # RAML Types: any, object, array, union via type expression,
    # one of the following scalar types: number, boolean, string,
    # date-only, time-only, datetime-only, datetime, file, integer, or nil
}
def _raml_type_for_dtype(type_str):
    """Map a pandas dtype name (``str(series.dtype)``) to a RAML type name.

    Exact names are looked up in ``_pd_type_lookup`` first. Parameterized names
    such as "datetime64[ns, UTC]" fall back to their base name. Sized or
    nullable numeric/boolean/datetime names map by prefix. Anything else
    becomes "string", since the result is a template meant to be edited by hand.
    """
    if type_str in _pd_type_lookup:
        return _pd_type_lookup[type_str]
    base = type_str.split("[", 1)[0]  # "datetime64[ns, UTC]" -> "datetime64"
    if base in _pd_type_lookup:
        return _pd_type_lookup[base]
    lowered = base.lower()
    if lowered.startswith(("int", "uint")):
        return "integer"
    if lowered.startswith("float"):
        return "number"
    if lowered.startswith("bool"):
        return "boolean"
    if lowered.startswith("datetime"):
        return "date"
    return "string"


def generate_raml(
    complete_conf,
    data_source_name=None,
    data_frame=None,
    resource_name="mythings",
):
    """Generate an example RAML definition for a model API from sample data.

    One row of the data is sampled. Each of its columns becomes a query
    parameter, with type, example value and (for categorical columns) enum
    derived from that row. The returned text also contains an example URL
    and instructions for adapting the template.

    Params:
        complete_conf:     complete configuration dictionary (``api.name`` and
                           ``model.version`` are read; datasources are read
                           when ``data_source_name`` is given)
        data_source_name:  name of a configured datasource to sample from
        data_frame:        DataFrame to sample from if no datasource name is given
        resource_name:     name of the RAML resource (without leading slash)

    Returns:
        The RAML template as a string.

    Raises:
        ValueError: if neither ``data_source_name`` nor ``data_frame`` is given.
    """
    # Stdlib imports
    from urllib.parse import quote_plus

    from .model_actions import _get_data_sources_and_sinks

    if data_source_name is not None:
        dso, _ = _get_data_sources_and_sinks(
            complete_conf, tags="", cache=True
        )
        df = dso[data_source_name].get_dataframe()
    elif data_frame is not None:
        df = data_frame
    else:
        raise ValueError("Please provide a data_source_name or a data_frame")
    sample = df.sample(1).reset_index(drop=True)
    api_name = complete_conf["api"]["name"]
    api_version = _get_major_api_version(complete_conf)
    url_start = f"http://127.0.0.1:5000/{quote_plus(api_name)}/{api_version}/{resource_name}?"
    url_params = []
    # Hints are only emitted for the first parameter (reset at the end of the loop).
    param_hints = """# can be false if optional, then provide a default here or be prepared to deal with missing values in your prediction
        #default: {}  # only makes sense for required: false
        #minimum: 0  # optional, maximum, minLength and others are also possible."""
    output = f"""#%RAML 0.8
---
title: Put title of your {api_name} API here
baseUri: https://{{host}}/{api_name}/{{version}}
version: {api_version}  # new version only for API-breaking updates
documentation:
  - title: Example section title
    content: |
      Example section contents

/{resource_name}:  # This should be a plural form of what you're predicting, e.g. "/client_activations"
  get:  # We can support 'post' as well if needed (let us know if necessary)
    description: Briefly describe what {resource_name} exactly you get from this API
    queryParameters:  # For all ways to specify parameters see https://github.com/raml-org/raml-spec/blob/master/versions/raml-08/raml-08.md#named-parameters"""
    illegal_chars = ':.,[]"\\ \n\t'
    for col_name in sample.columns:
        series = sample[col_name]
        # Column labels need not be strings (e.g. integer labels).
        col_label = str(col_name)
        type_str = str(series.dtype)
        raml_type = _raml_type_for_dtype(type_str)
        # tolist() yields plain Python scalars, so repr() gives e.g. "5"
        # instead of NumPy 2's "np.int64(5)".
        example = repr(series.tolist()[0])
        url_params.append(quote_plus(col_label) + "=" + quote_plus(example))
        quoted_col_name = (
            f"'{col_label}'"
            if any(c in illegal_chars for c in col_label)
            else col_label
        )
        cleaned_col_name = "".join(
            "_" if c in illegal_chars else c for c in col_label
        )
        output += f"""
      {quoted_col_name}:
        displayName: Friendly Name of {cleaned_col_name}
        type: {raml_type}
        description: Description of what {cleaned_col_name} really is
        example: {example}
        required: true  """ + param_hints.format(
            example
        )
        if type_str == "category":
            output += "\n        enum: " + str(list(series.cat.categories))
        param_hints = ""

    output += f"""
    responses:
      200:  # OK
        body:
          application/json:
            # This schema is optional but should fit your example prediction result below:
            schema: |
              {{
                "type": "object",
                "$schema": "http://json-schema.org/draft-03/schema",
                "id": "http://jsonschema.net",
                "required": true,
                "properties": {{
                  "prediction": {{
                    "type": "string",
                    "required": true,
                    "enum": ["Virginica", "Versicolor", "Setosa"]
                  }},
                  "probability": {{
                    "type": "number",
                    "required": false
                  }}
                }}
              }}
            # Provide an example of your prediction result:
            example: |
              {{
                "prediction": "Rainbows and Unicorns!"
              }}
  # /{{test_key}}:  # This becomes relevant if you want the API user to provide e.g. ids for the model to look up data to predict for
  #   get:
  #     queryParameters:
  #       hello:
  #         description: some demo query parameter in addition to the uri param
  #         type: string
  #         required: true
  #         enum: ['metric', 'imperial']
  #         #default: 42

# Example URL API call (to copy and paste into browser (e.g. Chrome) to test):
# {url_start + '&'.join(url_params)}

#########################################################################
# Above is printed an example RAML for your data to copy/paste into your
# RAML file and adapt to fully define your API.
#
# What you should do now:
# - Paste the above text into a <mymodel>.raml file (from the line '#%RAML 0.8')
# - Find and remove your target variable(s) in the RAML and example URL
# - Check and correct the resource name (currently /{resource_name})
# - Check the examples for sensitive data
# - Fill in the titles and descriptions, adapt as needed
#
# Official RAML specification: https://github.com/raml-org/raml-spec/blob/master/versions/raml-08/raml-08.md
#########################################################################
"""
    return output