Feat: Onboard Mnist Dataset (#379)
* feat: testing code

* feat: initial commit.

* feat: initial commit after code changes

* feat: initial commit after code changes

* feat: onboarding dataset mnist

* feat: onboarding mnist dataset, production ready

* feat: onboarding mnist dataset, production ready

* feat: changes done in pipeline yaml file, production ready
gkodukula committed Jun 16, 2022
1 parent cdbca70 commit 9809935
Showing 9 changed files with 479 additions and 0 deletions.
32 changes: 32 additions & 0 deletions datasets/mnist/infra/mnist_dataset.tf
@@ -0,0 +1,32 @@
/**
* Copyright 2021 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/


resource "google_storage_bucket" "mnist" {
name = "${var.bucket_name_prefix}-mnist"
force_destroy = true
location = "US"
uniform_bucket_level_access = true
lifecycle {
ignore_changes = [
logging,
]
}
}

output "storage_bucket-mnist-name" {
value = google_storage_bucket.mnist.name
}
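
The Terraform above provisions a US multi-region bucket with uniform bucket-level access. Below is a minimal verification sketch (not part of this commit), assuming the bucket has already been applied and application-default credentials are available; the bucket name is hypothetical and depends on var.bucket_name_prefix.

from google.cloud import storage

client = storage.Client()
# Hypothetical name: "<bucket_name_prefix>-mnist"
bucket = client.get_bucket("my-prefix-mnist")
print(bucket.location)  # expected: "US"
print(bucket.iam_configuration.uniform_bucket_level_access_enabled)  # expected: True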
28 changes: 28 additions & 0 deletions datasets/mnist/infra/provider.tf
@@ -0,0 +1,28 @@
/**
* Copyright 2021 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/


provider "google" {
project = var.project_id
impersonate_service_account = var.impersonating_acct
region = var.region
}

data "google_client_openid_userinfo" "me" {}

output "impersonating-account" {
value = data.google_client_openid_userinfo.me.email
}
26 changes: 26 additions & 0 deletions datasets/mnist/infra/variables.tf
@@ -0,0 +1,26 @@
/**
* Copyright 2021 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/


variable "project_id" {}
variable "bucket_name_prefix" {}
variable "impersonating_acct" {}
variable "region" {}
variable "env" {}
variable "iam_policies" {
default = {}
}

21 changes: 21 additions & 0 deletions datasets/mnist/pipelines/_images/run_csv_transform_kub/Dockerfile
@@ -0,0 +1,21 @@
# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

FROM python:3.8
ENV PYTHONUNBUFFERED True
COPY requirements.txt ./
RUN python3 -m pip install --no-cache-dir -r requirements.txt
WORKDIR /custom
COPY ./csv_transform.py .
CMD ["python3", "csv_transform.py"]
94 changes: 94 additions & 0 deletions datasets/mnist/pipelines/_images/run_csv_transform_kub/csv_transform.py
@@ -0,0 +1,94 @@
# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import datetime
import logging
import os
import pathlib

import requests
from google.cloud import storage


def main(
    source_url: str,
    source_file: pathlib.Path,
    target_file: pathlib.Path,
    target_gcs_bucket: str,
    target_gcs_path: str,
    pipeline_name: str,
) -> None:

    logging.info(
        f"ML datasets {pipeline_name} process started at "
        + str(datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S"))
    )

    logging.info("Creating 'files' folder")
    pathlib.Path("./files").mkdir(parents=True, exist_ok=True)

    logging.info(f"Downloading file from {source_url}... ")
    download_file(source_url, source_file)

    logging.info(
        f"Uploading output file to gs://{target_gcs_bucket}/{target_gcs_path}"
    )
    upload_file_to_gcs(target_file, target_gcs_bucket, target_gcs_path)

    logging.info(
        f"ML datasets {pipeline_name} process completed at "
        + str(datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S"))
    )


def download_file(source_url: str, source_file: pathlib.Path) -> None:
    logging.info(f"Downloading {source_url} into {source_file}")
    r = requests.get(source_url, stream=True)
    if r.status_code == 200:
        with open(source_file, "wb") as f:
            for chunk in r:
                f.write(chunk)
    else:
        logging.error(f"Couldn't download {source_url}: {r.text}")


def upload_file_to_gcs(
    file_path: pathlib.Path, target_gcs_bucket: str, target_gcs_path: str
) -> None:
    if os.path.exists(file_path):
        logging.info(
            f"Uploading output file to gs://{target_gcs_bucket}/{target_gcs_path}"
        )
        storage_client = storage.Client()
        bucket = storage_client.bucket(target_gcs_bucket)
        blob = bucket.blob(target_gcs_path)
        blob.upload_from_filename(file_path)
    else:
        logging.info(
            f"Cannot upload file to gs://{target_gcs_bucket}/{target_gcs_path} as it does not exist."
        )


if __name__ == "__main__":
    logging.getLogger().setLevel(logging.INFO)

    main(
        source_url=os.environ["SOURCE_URL"],
        source_file=pathlib.Path(os.environ["SOURCE_FILE"]).expanduser(),
        target_file=pathlib.Path(os.environ["TARGET_FILE"]).expanduser(),
        target_gcs_bucket=os.environ["TARGET_GCS_BUCKET"],
        target_gcs_path=os.environ["TARGET_GCS_PATH"],
        pipeline_name=os.environ["PIPELINE_NAME"],
    )
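
csv_transform.py downloads a single source file and re-uploads it to GCS, taking every input from environment variables. Below is a local-run sketch (not part of this commit), assuming the file above is importable as csv_transform.py from the working directory and that Google Cloud credentials with write access to a hypothetical bucket are configured.

import pathlib

import csv_transform  # assumes the script above is saved as csv_transform.py on the path

csv_transform.main(
    source_url="http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz",
    source_file=pathlib.Path("files/t10k-images-idx3-ubyte.gz"),
    target_file=pathlib.Path("files/t10k-images-idx3-ubyte.gz"),
    target_gcs_bucket="my-composer-bucket",  # hypothetical bucket name
    target_gcs_path="data/mnist/mnist/t10k-images-idx3-ubyte.gz",
    pipeline_name="mnist",
)

In the production DAG the same values arrive through the container's environment variables rather than keyword arguments.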
2 changes: 2 additions & 0 deletions datasets/mnist/pipelines/_images/run_csv_transform_kub/requirements.txt
@@ -0,0 +1,2 @@
google-cloud-storage
requests
26 changes: 26 additions & 0 deletions datasets/mnist/pipelines/dataset.yaml
@@ -0,0 +1,26 @@
# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

dataset:
  name: mnist
  friendly_name: mnist
  description: ~
  dataset_sources: ~
  terms_of_use: ~

resources:
  - type: storage_bucket
    name: mnist
    uniform_bucket_level_access: True
    location: US
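
The resources section of dataset.yaml mirrors the storage bucket defined in the Terraform above. A quick sanity check (not part of this commit), assuming PyYAML is installed and the path below matches a local checkout:

import yaml

with open("datasets/mnist/pipelines/dataset.yaml") as f:
    config = yaml.safe_load(f)

assert config["dataset"]["name"] == "mnist"
assert config["resources"][0]["type"] == "storage_bucket"
print(config["resources"][0])  # e.g. {'type': 'storage_bucket', 'name': 'mnist', ...}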
132 changes: 132 additions & 0 deletions datasets/mnist/pipelines/mnist/mnist_dag.py
@@ -0,0 +1,132 @@
# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from airflow import DAG
from airflow.providers.cncf.kubernetes.operators import kubernetes_pod

default_args = {
    "owner": "Google",
    "depends_on_past": False,
    "start_date": "2022-06-10",
}


with DAG(
    dag_id="mnist.mnist",
    default_args=default_args,
    max_active_runs=1,
    schedule_interval="@weekly",
    catchup=False,
    default_view="graph",
) as dag:

    # Task to copy `t10k-images-idx3-ubyte.gz` from MNIST Database to GCS
    download_and_process_source_zip_file = kubernetes_pod.KubernetesPodOperator(
        task_id="download_and_process_source_zip_file",
        name="mnist",
        namespace="composer",
        service_account_name="datasets",
        image_pull_policy="Always",
        image="{{ var.json.mnist.container_registry.run_csv_transform_kub }}",
        env_vars={
            "SOURCE_URL": "http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz",
            "SOURCE_FILE": "files/t10k-images-idx3-ubyte.gz",
            "TARGET_FILE": "files/t10k-images-idx3-ubyte.gz",
            "TARGET_GCS_BUCKET": "{{ var.value.composer_bucket }}",
            "TARGET_GCS_PATH": "data/mnist/mnist/t10k-images-idx3-ubyte.gz",
            "PIPELINE_NAME": "mnist",
        },
        resources={
            "request_memory": "2G",
            "request_cpu": "200m",
            "request_ephemeral_storage": "8G",
        },
    )

    # Task to copy `train-images-idx3-ubyte.gz` from MNIST Database to GCS
    download_and_process_source_zip_file_2 = kubernetes_pod.KubernetesPodOperator(
        task_id="download_and_process_source_zip_file_2",
        name="mnist",
        namespace="composer",
        service_account_name="datasets",
        image_pull_policy="Always",
        image="{{ var.json.mnist.container_registry.run_csv_transform_kub }}",
        env_vars={
            "SOURCE_URL": "http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz",
            "SOURCE_FILE": "files/train-images-idx3-ubyte.gz",
            "TARGET_FILE": "files/train-images-idx3-ubyte.gz",
            "TARGET_GCS_BUCKET": "{{ var.value.composer_bucket }}",
            "TARGET_GCS_PATH": "data/mnist/mnist/train-images-idx3-ubyte.gz",
            "PIPELINE_NAME": "mnist",
        },
        resources={
            "request_memory": "2G",
            "request_cpu": "200m",
            "request_ephemeral_storage": "8G",
        },
    )

    # Task to copy `train-labels-idx1-ubyte.gz` from MNIST Database to GCS
    download_and_process_source_zip_file_3 = kubernetes_pod.KubernetesPodOperator(
        task_id="download_and_process_source_zip_file_3",
        name="mnist",
        namespace="composer",
        service_account_name="datasets",
        image_pull_policy="Always",
        image="{{ var.json.mnist.container_registry.run_csv_transform_kub }}",
        env_vars={
            "SOURCE_URL": "http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz",
            "SOURCE_FILE": "files/train-labels-idx1-ubyte.gz",
            "TARGET_FILE": "files/train-labels-idx1-ubyte.gz",
            "TARGET_GCS_BUCKET": "{{ var.value.composer_bucket }}",
            "TARGET_GCS_PATH": "data/mnist/mnist/train-labels-idx1-ubyte.gz",
            "PIPELINE_NAME": "mnist",
        },
        resources={
            "request_memory": "2G",
            "request_cpu": "200m",
            "request_ephemeral_storage": "8G",
        },
    )

    # Task to copy `t10k-labels-idx1-ubyte.gz` from MNIST Database to GCS
    download_and_process_source_zip_file_4 = kubernetes_pod.KubernetesPodOperator(
        task_id="download_and_process_source_zip_file_4",
        name="mnist",
        namespace="composer",
        service_account_name="datasets",
        image_pull_policy="Always",
        image="{{ var.json.mnist.container_registry.run_csv_transform_kub }}",
        env_vars={
            "SOURCE_URL": "http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz",
            "SOURCE_FILE": "files/t10k-labels-idx1-ubyte.gz",
            "TARGET_FILE": "files/t10k-labels-idx1-ubyte.gz",
            "TARGET_GCS_BUCKET": "{{ var.value.composer_bucket }}",
            "TARGET_GCS_PATH": "data/mnist/mnist/t10k-labels-idx1-ubyte.gz",
            "PIPELINE_NAME": "mnist",
        },
        resources={
            "request_memory": "2G",
            "request_cpu": "200m",
            "request_ephemeral_storage": "8G",
        },
    )

    (
        download_and_process_source_zip_file
        >> download_and_process_source_zip_file_2
        >> download_and_process_source_zip_file_3
        >> download_and_process_source_zip_file_4
    )
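
The four KubernetesPodOperator tasks differ only in the MNIST file they fetch, and the chain at the end runs them strictly one after another. Below is a loop-based sketch of an alternative (not how this commit is written), assuming it sits inside the same `with DAG(...)` block and uses the same imports; the task_id pattern is assumed, resource requests are omitted for brevity, and with no explicit >> chain the downloads could run in parallel.

    # Alternative sketch (not part of this commit): generate one task per MNIST file.
    MNIST_FILES = [
        "t10k-images-idx3-ubyte.gz",
        "train-images-idx3-ubyte.gz",
        "train-labels-idx1-ubyte.gz",
        "t10k-labels-idx1-ubyte.gz",
    ]

    download_tasks = []
    for i, filename in enumerate(MNIST_FILES, start=1):
        download_tasks.append(
            kubernetes_pod.KubernetesPodOperator(
                task_id=f"download_and_process_source_zip_file_{i}",
                name="mnist",
                namespace="composer",
                service_account_name="datasets",
                image_pull_policy="Always",
                image="{{ var.json.mnist.container_registry.run_csv_transform_kub }}",
                env_vars={
                    "SOURCE_URL": f"http://yann.lecun.com/exdb/mnist/{filename}",
                    "SOURCE_FILE": f"files/{filename}",
                    "TARGET_FILE": f"files/{filename}",
                    "TARGET_GCS_BUCKET": "{{ var.value.composer_bucket }}",
                    "TARGET_GCS_PATH": f"data/mnist/mnist/{filename}",
                    "PIPELINE_NAME": "mnist",
                },
            )
        )
    # Without an explicit dependency chain, Airflow schedules all four tasks independently.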
