gpicciuca committed
Commit 058f1d9 · 0 Parent(s)

First commit

Dockerfile ADDED
@@ -0,0 +1,21 @@
+# This Dockerfile serves as the build file for Huggingface Spaces
+FROM huggingface/transformers-pytorch-gpu
+
+ARG ML_APP_LISTEN_PORT=7860
+ARG ML_MLFLOW_ENDPOINT
+ARG ML_HF_ACCESS_TOKEN
+
+ENV APP_LISTEN_PORT=${ML_APP_LISTEN_PORT}
+ENV MLFLOW_ENDPOINT=${ML_MLFLOW_ENDPOINT}
+ENV HF_ACCESS_TOKEN=${ML_HF_ACCESS_TOKEN}
+
+RUN apt-get update
+RUN /usr/bin/python3 -m pip install uvicorn fastapi mlflow huggingface_hub httpx
+
+WORKDIR /app
+
+COPY ./app /app
+
+EXPOSE ${ML_APP_LISTEN_PORT}
+
+ENTRYPOINT [ "/usr/bin/python3", "/app/main.py" ]
README.md ADDED
@@ -0,0 +1,8 @@
+---
+title: SocialMedia Sentiment Analysis
+emoji: 🐳
+colorFrom: purple
+colorTo: gray
+sdk: docker
+app_port: 7860
+---
app/config.py ADDED
@@ -0,0 +1,13 @@
+import os
+
+def is_test_mode():
+    """
+    Checks whether test mode is enabled or not.
+
+    Returns:
+        bool: True if test mode is enabled, False otherwise
+    """
+    return os.environ.get("TEST_MODE", None) == "1"
+
+def enable_test_mode():
+    os.environ["TEST_MODE"] = "1"
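
For reference, test mode is toggled purely through the `TEST_MODE` environment variable, so any process (including the pytest fixture below) can enable it before heavy code paths run. A minimal usage sketch (illustrative, not part of this commit; assumes it runs from inside `app/`):

```python
# Illustrative sketch of the TEST_MODE contract (not part of this commit).
import config

assert not config.is_test_mode()  # TEST_MODE unset -> normal mode
config.enable_test_mode()         # sets os.environ["TEST_MODE"] = "1"
assert config.is_test_mode()      # training tasks will now short-circuit
```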
app/main.py ADDED
@@ -0,0 +1,115 @@
+from fastapi import FastAPI, BackgroundTasks
+from fastapi.responses import JSONResponse, HTMLResponse
+from pydantic import BaseModel
+import uvicorn
+from uvicorn.config import logger
+import os
+import argparse
+from tasks.training import TrainingTask
+from config import enable_test_mode
+
+app = FastAPI()
+
+@app.post("/train/start", response_class=JSONResponse)
+async def start_model_training(background_tasks: BackgroundTasks):
+    """
+    Endpoint on which a request can be sent to start model re-training,
+    if there is no training task currently running.
+    The task is carried out in the background and its status can be
+    polled via /train/get_state.
+
+    Args:
+        background_tasks (BackgroundTasks): Background task scheduler provided by FastAPI
+
+    Returns:
+        dict: A dictionary containing a message with the outcome of the request.
+    """
+    if not TrainingTask.has_instance():
+        background_tasks.add_task(TrainingTask.get_instance())
+
+        return {
+            "message": "Model training was scheduled and will begin shortly.",
+        }
+
+    return {
+        "message": "A training instance is already running.",
+    }
+
+@app.post("/train/get_state", response_class=JSONResponse)
+async def poll_model_training_state():
+    """
+    Checks whether a training task is currently ongoing.
+    If so, returns whether it is done and whether an error occurred.
+    If no instance is running, returns only a message.
+
+    Returns:
+        dict: Dictionary containing either done/error or message.
+    """
+    if TrainingTask.has_instance():
+        train_instance: TrainingTask = TrainingTask.get_instance()
+        is_done = train_instance.is_done()
+        has_error = train_instance.has_error()
+
+        if is_done:
+            TrainingTask.clear_instance()
+
+        return {
+            "done": is_done,
+            "error": has_error,
+        }
+
+    return {
+        "message": "No training instance running!",
+    }
+
+class InferenceRequest(BaseModel):
+    """
+    Provides a model/schema for the accepted request body of incoming
+    inference requests.
+    """
+    messages: list[str]
+
+@app.post("/inference", response_class=JSONResponse)
+async def inference(data: InferenceRequest):
+    """
+    Runs sentiment analysis on each of the submitted messages.
+
+    Args:
+        data (InferenceRequest): Structure containing a list of
+        messages that shall be evaluated
+
+    Returns:
+        json: A JSON list containing the sentiment analysis for each message.
+        Each element consists of a dictionary with the following keys:
+        positive, neutral, negative
+    """
+
+    from tasks.inference import infer_task
+    return infer_task.predict(data.messages)
+
+@app.get("/", response_class=HTMLResponse)
+async def root():
+    """
+    The root endpoint for our hosted application. Only shows a message
+    indicating that the app is up and running.
+
+    Returns:
+        str: An HTML response containing a hello-world-like string
+    """
+    return "Hi there! It's a nice blank page, isn't it?"
+
+if __name__ == "__main__":
+    """
+    Entrypoint for the application when executed via command line.
+    It accepts an optional positional argument "test"; "yes" enables test mode.
+    """
+    parser = argparse.ArgumentParser()
+    parser.add_argument("test", nargs="?", default="no")
+    args = parser.parse_args()
+
+    if args.test == "yes":
+        enable_test_mode()
+
+    config = uvicorn.Config("main:app", host="0.0.0.0", port=int(os.environ["APP_LISTEN_PORT"]), log_level="debug")
+    server = uvicorn.Server(config)
+    server.run()
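
Since `httpx` is installed in the image, the endpoints above can be smoke-tested with a small client script. The sketch below is illustrative and not part of the commit; host and port assume the Dockerfile defaults:

```python
# Hypothetical smoke-test client for the API above (not part of this commit).
# Assumes the app is reachable at localhost:7860.
import time
import httpx

BASE = "http://localhost:7860"

# Kick off background training, then poll until the task reports done
# or no instance is running anymore.
print(httpx.post(f"{BASE}/train/start").json())
while True:
    state = httpx.post(f"{BASE}/train/get_state").json()
    print(state)
    if state.get("done") or "message" in state:
        break
    time.sleep(10)

# Each inference result maps positive/neutral/negative to a softmax score.
scores = httpx.post(f"{BASE}/inference",
                    json={"messages": ["BTC is going to skyrocket!"]}).json()
print(scores)
```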
app/tasks/__init__.py ADDED
File without changes
app/tasks/inference.py ADDED
@@ -0,0 +1,115 @@
+from uvicorn.config import logger
+from transformers import AutoModelForSequenceClassification
+from transformers import AutoTokenizer, AutoConfig
+import numpy as np
+import mlflow
+import os
+import time
+from scipy.special import softmax
+
+# HuggingFace model to be used for inferencing
+MODEL = "cardiffnlp/twitter-roberta-base-sentiment-latest"
+
+class InferenceTask:
+
+    def __init__(self):
+        self.clear()
+        self.load_model()
+
+    def load_model(self):
+        try:
+            self.__tokenizer = AutoTokenizer.from_pretrained(MODEL)
+            self.__config = AutoConfig.from_pretrained(MODEL)
+            self.__model = AutoModelForSequenceClassification.from_pretrained(MODEL)
+            self.__is_loaded = True
+        except Exception as ex:
+            logger.error(f"Failed to load inference model: {ex}")
+            self.clear()
+            return False
+
+        return True
+
+    def clear(self):
+        self.__is_loaded = False
+        self.__tokenizer = None
+        self.__config = None
+        self.__model = None
+
+    def is_loaded(self):
+        return self.__is_loaded
+
+    def predict(self, messages: list[str]):
+        if len(messages) == 0:
+            return None
+
+        if not self.is_loaded() and not self.load_model():
+            return None
+
+        mlflow.set_tracking_uri(os.environ["MLFLOW_ENDPOINT"])
+        mlflow.set_experiment("Sentiment Analysis")
+
+        with mlflow.start_run() as run:
+            preprocessed_messages = self.__preprocess(messages)
+            labelized_scores = []
+
+            for message in preprocessed_messages:
+                encoded_input = self.__tokenizer(message, return_tensors='pt', padding="longest")
+                output = self.__model(**encoded_input)
+                scores = output[0][0].detach().numpy()
+                scores = softmax(scores)
+                scores = self.__labelize(scores)
+                labelized_scores.append(scores)
+
+            mean_sentiment = self.__calculate_mean_sentiment(labelized_scores)
+            mean_sentiment["samples"] = len(labelized_scores)
+            logger.info(mean_sentiment)
+
+            mlflow.log_metrics(mean_sentiment, step=int(time.time()))
+
+        return labelized_scores
+
+    def __calculate_mean_sentiment(self, labelized_scores: list):
+        total_samples = float(len(labelized_scores))
+
+        mean_sentiment = {
+            "positive": 0.0,
+            "neutral": 0.0,
+            "negative": 0.0,
+        }
+
+        for score in labelized_scores:
+            mean_sentiment["positive"] += score["positive"]
+            mean_sentiment["neutral"] += score["neutral"]
+            mean_sentiment["negative"] += score["negative"]
+
+        mean_sentiment["positive"] /= total_samples
+        mean_sentiment["neutral"] /= total_samples
+        mean_sentiment["negative"] /= total_samples
+
+        return mean_sentiment
+
+    # Preprocess text (username and link placeholders)
+    def __preprocess(self, messages: list[str]):
+        msg_list = []
+        for message in messages:
+            new_message = []
+            for t in message.split(" "):
+                t = '@user' if t.startswith('@') and len(t) > 1 else t
+                t = 'http' if t.startswith('http') else t
+                new_message.append(t)
+            msg_list.append(" ".join(new_message))
+        return msg_list
+
+    def __labelize(self, scores):
+        output = {}
+        ranking = np.argsort(scores)
+        ranking = ranking[::-1]
+        for i in range(scores.shape[0]):
+            label = self.__config.id2label[ranking[i]]
+            score = float(scores[ranking[i]])
+            output[label] = score
+        return output
+
+# Preload a global instance so that inference can be
+# executed immediately when requested
+infer_task = InferenceTask()
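
To make the data flow concrete: `__preprocess` masks user handles and links before tokenization, and `__labelize` converts the softmax vector into a label-to-score dict via the model config. An illustrative sketch of what `predict` returns (the scores shown are made up):

```python
# Illustrative only (not part of this commit).
messages = ["@someuser BTC is going to skyrocket! https://example.com"]
# __preprocess masks handles and links, yielding:
#   "@user BTC is going to skyrocket! http"

results = infer_task.predict(messages)
# Expected shape: one dict per input message, e.g. (fabricated values):
#   [{"positive": 0.91, "neutral": 0.07, "negative": 0.02}]
```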
app/tasks/training.py ADDED
@@ -0,0 +1,248 @@
+import evaluate
+import numpy as np
+from uvicorn.config import logger
+from datasets import load_dataset
+from transformers import (
+    AutoModelForSequenceClassification,
+    AutoTokenizer,
+    Trainer,
+    TrainingArguments,
+    pipeline,
+)
+from huggingface_hub import login, logout
+
+import os
+import mlflow
+from tasks.inference import infer_task
+from config import is_test_mode
+
+"""
+Documentation:
+- https://huggingface.co/docs/transformers/en/training
+- https://mlflow.org/docs/latest/llms/transformers/tutorials/fine-tuning/transformers-fine-tuning
+"""
+
+MODEL = "cardiffnlp/twitter-roberta-base-sentiment-latest"
+DATASET = "zeroshot/twitter-financial-news-sentiment"
+HF_DEST_REPO = "financial-twitter-roberta-sentiment"
+
+RNG_SEED = 22
+
+class TrainingTask:
+
+    TRAINING_TASK_INST_SINGLETON = None
+
+    def __init__(self):
+        self.__is_done = False
+        self.__has_error = False
+
+        self.__train_dataset = None
+        self.__test_dataset = None
+        self.__tokenizer = None
+        self.__train_tokenized = None
+        self.__test_tokenized = None
+        self.__model = None
+        self.__trainer = None
+        self.__run_id = None
+
+    @staticmethod
+    def has_instance():
+        return TrainingTask.TRAINING_TASK_INST_SINGLETON is not None
+
+    @staticmethod
+    def get_instance():
+        if TrainingTask.TRAINING_TASK_INST_SINGLETON is None:
+            TrainingTask.TRAINING_TASK_INST_SINGLETON = TrainingTask()
+
+        return TrainingTask.TRAINING_TASK_INST_SINGLETON
+
+    @staticmethod
+    def clear_instance():
+        del TrainingTask.TRAINING_TASK_INST_SINGLETON
+        TrainingTask.TRAINING_TASK_INST_SINGLETON = None
+
+    def has_error(self):
+        return self.__has_error
+
+    def is_done(self):
+        return self.__is_done
+
+    def __call__(self, *args, **kwds):
+        self.__has_error = False
+        self.__is_done = False
+
+        if is_test_mode():
+            # Simulate a successful training run in test mode
+            self.__has_error = False
+            self.__is_done = True
+            return
+
+        login(token=os.environ["HF_ACCESS_TOKEN"])
+
+        try:
+            self.__load_datasets()
+            self.__tokenize()
+            self.__load_model()
+            self.__train()
+            self.__evaluate()
+            self.__deploy()
+        except Exception as ex:
+            logger.error(f"Error during training: {ex}")
+            self.__has_error = True
+        finally:
+            self.__is_done = True
+
+        logout()
+
+        self.__reload_inference_model()
+
+    def __load_datasets(self):
+        # Load the dataset.
+        dataset = load_dataset(DATASET)
+
+        # Split train/test by an 8/2 ratio.
+        dataset_train_test = dataset["train"].train_test_split(test_size=0.2)
+        self.__train_dataset = dataset_train_test["train"]
+        self.__test_dataset = dataset_train_test["test"]
+
+        # Remap labels so that they match the model configuration.
+        # The dataset uses {0: Bearish, 1: Bullish, 2: Neutral}, while the
+        # model is loaded with {0: Bearish, 1: Neutral, 2: Bullish} (see
+        # __load_model), so flip 1<->2 to stay consistent.
+        def label_filter(row):
+            row["label"] = { 0: 0, 1: 2, 2: 1 }[row["label"]]
+            return row
+
+        self.__train_dataset = self.__train_dataset.map(label_filter)
+        self.__test_dataset = self.__test_dataset.map(label_filter)
+
+    def __tokenize(self):
+        # Load the tokenizer for the model.
+        self.__tokenizer = AutoTokenizer.from_pretrained(MODEL)
+
+        def tokenize_function(examples):
+            # Pad/truncate each text to 256 tokens. Enforcing the same shape
+            # can make training faster.
+            return self.__tokenizer(
+                examples["text"],
+                padding="max_length",
+                truncation=True,
+                max_length=256,
+            )
+
+        # Tokenize the train and test datasets
+        self.__train_tokenized = self.__train_dataset.map(tokenize_function)
+        self.__train_tokenized = self.__train_tokenized.remove_columns(["text"]).shuffle(seed=RNG_SEED)
+
+        self.__test_tokenized = self.__test_dataset.map(tokenize_function)
+        self.__test_tokenized = self.__test_tokenized.remove_columns(["text"]).shuffle(seed=RNG_SEED)
+
+    def __load_model(self):
+        # Set the mapping between int label and its meaning.
+        id2label = {0: "Bearish", 1: "Neutral", 2: "Bullish"}
+        label2id = {"Bearish": 0, "Neutral": 1, "Bullish": 2}
+
+        # Acquire the model from the Hugging Face Hub, providing label and id mappings so that both we and the model can 'speak' the same language.
+        self.__model = AutoModelForSequenceClassification.from_pretrained(
+            MODEL,
+            num_labels=3,
+            label2id=label2id,
+            id2label=id2label,
+        )
+
+    def __train(self):
+        # Define the target optimization metric
+        metric = evaluate.load("accuracy")
+
+        # Define a function for calculating our defined target optimization metric during training
+        def compute_metrics(eval_pred):
+            logits, labels = eval_pred
+            predictions = np.argmax(logits, axis=-1)
+            return metric.compute(predictions=predictions, references=labels)
+
+        # Checkpoints will be output to this `training_output_dir`.
+        training_output_dir = "/tmp/sentiment_trainer"
+        training_args = TrainingArguments(
+            output_dir=training_output_dir,
+            eval_strategy="epoch",
+            hub_model_id=HF_DEST_REPO,  # destination repo used by push_to_hub() in __deploy
+            per_device_train_batch_size=8,
+            per_device_eval_batch_size=8,
+            logging_steps=8,
+            num_train_epochs=3,
+        )
+
+        # Instantiate a `Trainer` instance that will be used to initiate a training run.
+        self.__trainer = Trainer(
+            model=self.__model,
+            args=training_args,
+            train_dataset=self.__train_tokenized,
+            eval_dataset=self.__test_tokenized,
+            compute_metrics=compute_metrics,
+        )
+
+        mlflow.set_tracking_uri(os.environ["MLFLOW_ENDPOINT"])
+        mlflow.set_experiment("Sentiment Classifier Training")
+
+        with mlflow.start_run() as run:
+            self.__run_id = run.info.run_id
+            self.__trainer.train()
+
+    def __evaluate(self):
+        tuned_pipeline = pipeline(
+            task="text-classification",
+            model=self.__trainer.model,
+            batch_size=8,
+            tokenizer=self.__tokenizer,
+            device="cpu",  # or cuda
+        )
+
+        quick_check = (
+            "I have a question regarding the project development timeline and allocated resources; "
+            "specifically, how certain are you that John and Ringo can work together on writing this next song? "
+            "Do we need to get Paul involved here, or do you truly believe, as you said, 'nah, they got this'?"
+        )
+
+        result = tuned_pipeline(quick_check)
+        logger.debug("Test evaluation of fine-tuned model: %s %.6f" % (result[0]["label"], result[0]["score"]))
+
+        # Define a set of parameters that we would like to be able to flexibly override at inference time, along with their default values
+        model_config = {"batch_size": 8}
+
+        # Infer the model signature, including a representative input, the expected output, and the parameters that we would like to be able to override at inference time.
+        signature = mlflow.models.infer_signature(
+            ["This is a test!", "And this is also a test."],
+            mlflow.transformers.generate_signature_output(
+                tuned_pipeline, ["This is a test response!", "So is this."]
+            ),
+            params=model_config,
+        )
+
+        # Log the pipeline to the existing training run
+        with mlflow.start_run(run_id=self.__run_id):
+            model_info = mlflow.transformers.log_model(
+                transformers_model=tuned_pipeline,
+                artifact_path="fine_tuned",
+                signature=signature,
+                input_example=["Pass in a string", "And have its sentiment classified."],
+                model_config=model_config,
+            )
+
+        # Load our saved model in the native transformers format
+        loaded = mlflow.transformers.load_model(model_uri=model_info.model_uri)
+
+        # Define a test example with strongly negative, spam-like wording
+        validation_text = (
+            "Want to learn how to make MILLIONS with no effort? Click HERE now! See for yourself! Guaranteed to make you instantly rich! "
+            "Don't miss out you could be a winner!"
+        )
+
+        # Validate the behavior of our fine-tuned model
+        loaded(validation_text)
+
+    def __deploy(self):
+        # The target repo comes from TrainingArguments.hub_model_id (set in __train).
+        self.__trainer.push_to_hub()
+
+    def __reload_inference_model(self):
+        infer_task.load_model()
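
The label remapping in `__load_datasets` is the easiest place to introduce silent errors, so it is worth sanity-checking in isolation. A small self-contained check, assuming the dataset card's mapping of {0: Bearish, 1: Bullish, 2: Neutral}:

```python
# Standalone sanity check for the 1<->2 label flip (not part of this commit).
dataset_id2label = {0: "Bearish", 1: "Bullish", 2: "Neutral"}  # dataset card
model_id2label = {0: "Bearish", 1: "Neutral", 2: "Bullish"}    # __load_model
flip = {0: 0, 1: 2, 2: 1}                                      # label_filter

for raw_id, name in dataset_id2label.items():
    # The label's meaning must be preserved after remapping.
    assert model_id2label[flip[raw_id]] == name
```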
app/test/__init__.py ADDED
File without changes
app/test/fixture.py ADDED
@@ -0,0 +1,19 @@
+import pytest
+from fastapi.testclient import TestClient
+from main import app
+import os
+
+@pytest.fixture()
+def app_client():
+    """
+    Bare-bones test fixture that initializes a FastAPI TestClient
+    which can be used to test all endpoints provided by the application.
+
+    Yields:
+        TestClient: A client hosting the whole application so that it
+        can be accessed and controlled programmatically.
+    """
+    os.environ["TEST_MODE"] = "1"  # Turns off actual model training
+
+    client = TestClient(app)
+    yield client
app/test/test_inferencing_endpoint.py ADDED
@@ -0,0 +1,86 @@
+from .fixture import app_client
+import json
+
+def test_inference_endpoint(app_client):
+    """
+    Tests the output of the inference endpoint of the application.
+
+    Given:
+    - Payload with valid list of messages
+    When:
+    - POST request sent to inference endpoint
+    Then:
+    - Expect message to be classified as positive
+    """
+    response = app_client.post(
+        "/inference",
+        headers={
+            "Content-Type": "application/json",
+        },
+        json={
+            "messages": [
+                "BTC is going to skyrocket!",
+            ],
+        }
+    )
+
+    assert response.status_code == 200
+    output = json.loads(response.content)
+    assert isinstance(output, list)
+    assert len(output) == 1
+    assert output[0]["positive"] > output[0]["negative"] and output[0]["positive"] > output[0]["neutral"]
+
+
+def test_inference_endpoint_with_wrong_payload(app_client):
+    """
+    Tests the output of the inference endpoint of the application with an
+    invalid payload.
+    This should yield a 422 status error, as FastAPI will not be able
+    to translate the payload into the InferenceRequest model.
+
+    Given:
+    - Payload with wrong message key
+    When:
+    - POST request sent to inference endpoint
+    Then:
+    - Expect 422 status code
+    """
+    response = app_client.post(
+        "/inference",
+        headers={
+            "Content-Type": "application/json",
+        },
+        json={
+            "msgs": [
+                "BTC is going to skyrocket!",
+            ],
+        }
+    )
+
+    assert response.status_code == 422  # Unprocessable entity
+
+def test_inference_endpoint_with_no_prompt(app_client):
+    """
+    Tests the output of the inference endpoint of the application
+    when a valid payload is provided but with no actual messages.
+
+    Given:
+    - Payload without any messages
+    When:
+    - POST request sent to inference endpoint
+    Then:
+    - Expect no error and correct format
+    """
+    response = app_client.post(
+        "/inference",
+        headers={
+            "Content-Type": "application/json",
+        },
+        json={
+            "messages": [],
+        }
+    )
+
+    assert response.status_code == 200
+    output = json.loads(response.content)
+    assert output is None
app/test/test_training_endpoint.py ADDED
@@ -0,0 +1,40 @@
+from .fixture import app_client
+import json
+import time
+
+def test_training_endpoint(app_client):
+    """
+    Checks whether the training endpoint correctly receives, starts
+    and clears a training task instance.
+
+    Given:
+    - Launched training instance
+
+    When:
+    - State polled multiple times
+
+    Then:
+    - Expect state returned on first poll and instance gone on second poll
+    """
+    response = app_client.post("/train/start")
+    assert response.status_code == 200
+    output: dict = json.loads(response.content)
+    assert len(output.keys()) == 1
+    assert output["message"] == "Model training was scheduled and will begin shortly."
+
+    time.sleep(5)
+
+    response = app_client.post("/train/get_state")
+    assert response.status_code == 200
+    output: dict = json.loads(response.content)
+    assert len(output.keys()) == 2
+    assert output["done"]
+    assert not output["error"]
+
+    time.sleep(1)
+
+    response = app_client.post("/train/get_state")
+    assert response.status_code == 200
+    output: dict = json.loads(response.content)
+    assert len(output.keys()) == 1
+    assert output["message"] == "No training instance running!"
docker-compose.test.yaml ADDED
@@ -0,0 +1,5 @@
+services:
+  # Override parameters of this service
+  model_runner:
+    # Sets the entrypoint so that pytest is executed
+    entrypoint: ["pytest"]
docker-compose.yaml ADDED
@@ -0,0 +1,25 @@
+services:
+  #
+  model_runner:
+    build:
+      context: .
+      dockerfile: Dockerfile
+
+    environment:
+      - APP_LISTEN_PORT=${APP_LISTEN_PORT}
+      - MLFLOW_ENDPOINT=${MLFLOW_ENDPOINT}
+      - HF_ACCESS_TOKEN=${HF_ACCESS_TOKEN}
+
+    container_name: mlrunner
+    restart: on-failure
+    ports:
+      - "${APP_LISTEN_PORT}:${APP_LISTEN_PORT}"
+    volumes:
+      - "./app:/app"
+    entrypoint: ["/usr/bin/python3", "/app/main.py"]
+    networks:
+      - airflow_tracking_network
+
+networks:
+  airflow_tracking_network:
+    external: true
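
With both compose files in place, the test entrypoint can presumably be activated through Compose's file-merging mechanism, e.g. `docker compose -f docker-compose.yaml -f docker-compose.test.yaml up`, which overrides the server entrypoint with `pytest`.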