# Source code for examples.records_datasource
import logging
from typing import Dict, Optional
from urllib.parse import quote

from mllaunchpad.resource import DataSource, get_user_pw

logger = logging.getLogger(__name__)

try:
    import records
except ModuleNotFoundError:
    logger.warning("Please install the Records package to be able to use RecordsDbDataSource.")
class RecordsDbDataSource(DataSource):
    """DataSource for a bunch of relational database types:
    RedShift, Postgres, MySQL, SQLite, Oracle, Microsoft SQL.

    EXPERIMENTAL

    See :attr:`serves` for the available types.

    Creates a long-living connection on initialization.

    Configuration::

        dbms:
          # ... (other connections)
          my_connection:  # NOTE: You can use the same connection for several datasources and datasinks
            type: oracle  # see attribute serves for available types
            host: host.example.com
            port: 1251  # optional
            user_var: MY_USER_ENV_VAR
            password_var: MY_PW_ENV_VAR  # optional
            service_name: servicename.example.com  # optional
            options: {}  # used as **kwargs when initializing the DB connection
          # ...
        datasources:
          # ... (other datasources)
          my_datasource:
            type: dbms.my_connection
            query: SELECT * FROM users.my_table where id = :id  # fill `:params` by calling `get_dataframe` with a `dict`
            expires: 0     # generic parameter, see documentation on DataSources
            tags: [train]  # generic parameter, see documentation on DataSources and DataSinks
    """

    serves = [
        "dbms.oracle",
        "dbms.redshift",
        "dbms.postgres",
        "dbms.mysql",
        "dbms.sqlite",
        "dbms.ms_sql",
    ]

    # Configured type names that differ from the dialect name expected in the
    # URL scheme of the connection string (the connection string is handed to
    # `records.Database`, which is built on SQLAlchemy). Types not listed here
    # are used verbatim.
    _DIALECT_ALIASES = {
        "postgres": "postgresql",
        "ms_sql": "mssql",
    }

    def __init__(
        self,
        identifier: str,
        datasource_config: Dict,
        dbms_config: Optional[Dict],
    ) -> None:
        """Build the connection string from `dbms_config` and open the database.

        :param identifier: Passed on unchanged to the base class initializer
        :type identifier: str
        :param datasource_config: Passed on unchanged to the base class initializer
        :type datasource_config: dict
        :param dbms_config: Connection settings. The keys ``type``, ``host``,
            ``user_var`` and ``password_var`` are read unconditionally;
            ``port``, ``service_name`` and ``options`` are used if present.
        :type dbms_config: dict
        :raises ValueError: If `dbms_config` is None
        :raises KeyError: If one of the unconditionally read keys is missing
        """
        super().__init__(identifier, datasource_config)
        if dbms_config is None:
            # Fail with a clear message instead of an opaque
            # "'NoneType' object is not subscriptable" further down.
            raise ValueError(
                "No dbms configuration provided for datasource {}".format(identifier)
            )
        self.dbms_config = dbms_config

        logger.info(
            "Establishing Records database connection for datasource %s...",
            self.id,
        )

        # if "connect" not in dbms_config:
        #     raise ValueError(f'No connection string (property "connect") in datasource {self.id} config')
        configured_type = dbms_config["type"]
        dbtype = self._DIALECT_ALIASES.get(configured_type, configured_type)

        # NOTE(review): the class docs mark `password_var` as optional, yet the
        # key is read unconditionally here (KeyError if absent) -- confirm the
        # intended behaviour against `get_user_pw`.
        user, pw = get_user_pw(dbms_config["user_var"], dbms_config["password_var"])

        # Percent-encode the credentials: characters such as "@", ":", "/" or
        # "%" would otherwise be interpreted as URL delimiters.
        credentials = quote(str(user), safe="")
        if pw is not None:
            # Presumably `get_user_pw` may hand back None for an unset password
            # (TODO confirm); never let the literal text "None" into the URL.
            credentials = "{}:{}".format(credentials, quote(str(pw), safe=""))

        host = dbms_config["host"]
        port = ":{}".format(dbms_config["port"]) if "port" in dbms_config else ""
        service_name = (
            "/?service_name=" + dbms_config["service_name"]
            if "service_name" in dbms_config
            else ""
        )
        # `or {}` also covers an explicit `options: null` in the configuration.
        kwargs = dbms_config.get("options") or {}

        connection_string = f"{dbtype}://{credentials}@{host}{port}{service_name}"
        self.db = records.Database(connection_string, **kwargs)

    def get_dataframe(self, params: Optional[Dict] = None, chunksize=None):
        """Get data as a pandas dataframe.

        Example::

            data_sources["my_datasource"].get_dataframe({"id": 387})

        :param params: Query parameters to fill in query (e.g. `:id` with value 387)
        :type params: optional dict
        :param chunksize: Currently not implemented
        :type chunksize: optional bool
        :raises NotImplementedError: If a truthy `chunksize` is passed
        :return: DataFrame object, possibly cached according to expires-config
        """
        if chunksize:
            raise NotImplementedError("Buffered reading not supported yet")
            # the resulting `rows` of a query provides a nice way to do this, though

        query = self.config["query"]
        params = params or {}
        logger.debug("Fetching query %s with params %s...", query, params)

        # The entries of `params` are forwarded as keyword arguments and fill
        # the `:name` placeholders of the configured query.
        rows = self.db.query(query, fetchall=True, **params)
        df = rows.export("df")
        return df

    def get_raw(self, params=None, chunksize=None):
        """Not implemented.

        :param params: Query parameters to fill in query (e.g. `:id` with value 387)
        :type params: optional dict
        :param chunksize: Currently not implemented
        :type chunksize: optional bool
        :raises NotImplementedError:
        :return: Nothing, always raises NotImplementedError
        """
        raise NotImplementedError("Raw data not supported")

    # def __del__(self):
    #     if hasattr(self, "db"):
    #         self.db.close()