# ml_service.py
# Forked from Shazid08/Dis-Easify_Deploy_Source.
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from typing import List, Union
import numpy as np
from joblib import load
import tensorflow as tf
import os
import logging
import base64
from PIL import Image
import io
# Module logger; the root handler is configured at INFO level just below.
logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)

# The ASGI application served by uvicorn (see the __main__ guard at the bottom).
app = FastAPI(title="Disease Prediction API")

# Cross-origin access is fully open: any origin, method and header.
# NOTE(review): FastAPI/Starlette CORS docs say wildcard settings cannot be
# combined with credentialed requests; if the frontend relies on cookies or
# auth headers, list its origins explicitly instead of "*".
_CORS_OPTIONS = dict(
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
app.add_middleware(CORSMiddleware, **_CORS_OPTIONS)
# Registry of successfully loaded models, keyed by the public model name used
# in requests. Filled at import time by the loader loops that follow; a model
# that fails to load is simply absent from this dict.
models = {}

# Models persisted with joblib: public name -> file name under /savedModels.
joblib_models = dict(
    breast_cancer='breast_cancer_rfc_model.joblib',
    diabetes='diabetes_dtc_model.joblib',
    disease_gnb='disease_gnb_model.joblib',
    heart='heart_dtc_model1.joblib',
)

# Keras models: public name -> saved file plus the per-sample input shape
# (batch dimension excluded). /predict prepends the batch axis when reshaping,
# and /predict_image feeds 36x36 single-channel (grayscale) images.
tf_models = {
    'pneumonia': dict(
        file='pneumonia2.h5',
        input_shape=(36, 36, 1),
    ),
}
# Load every joblib model into the registry. Each failure is logged and skipped,
# so the service still starts with whatever did load (/health reports the rest
# as "failed").
for model_name, model_file in joblib_models.items():
    model_path = f'/savedModels/{model_file}'
    try:
        logger.info(f"Attempting to load {model_name} from {model_path}")
        loaded = load(model_path)
        models[model_name] = loaded
        logger.info(f"Loaded {model_name} successfully")
    except Exception as e:
        logger.error(f"Error loading {model_name}: {e}")
# Alternative weight files tried when a TensorFlow model's primary file fails
# to load. NOTE(review): pneumonia3.h5 is assumed to accept the same
# input_shape as the primary file configured in tf_models — confirm.
_TF_FALLBACK_FILES = {
    'pneumonia': 'pneumonia3.h5',
}


def _load_keras_model(model_path):
    """Load a Keras model from ``model_path`` and compile it.

    The model is loaded with ``compile=False`` and then compiled with Adam and
    binary cross-entropy, so any training configuration stored in the file is
    ignored. Input and output shapes are logged for diagnostics.

    Raises:
        FileNotFoundError: ``model_path`` does not exist.
        Exception: anything ``tf.keras.models.load_model`` raises for an
            unreadable or incompatible file.
    """
    if not os.path.exists(model_path):
        logger.error(f"Model file does not exist: {model_path}")
        raise FileNotFoundError(f"Model file not found: {model_path}")
    model = tf.keras.models.load_model(model_path, compile=False)
    logger.info(f"Model Input Shape: {model.input_shape}")
    logger.info(f"Model Output Shape: {model.output_shape}")
    # predict() does not require a compiled model; compiling is kept so the
    # registered model behaves as before for callers that use evaluate()/fit().
    model.compile(
        optimizer='adam',
        loss='binary_crossentropy',
        metrics=['accuracy'],
    )
    return model


# Load TensorFlow models. If the primary file fails, try the model's fallback
# file (if one is configured). Failures are logged rather than raised, so the
# service still starts without the model.
# Previously, the duplicated except chain here recompiled a possibly-unbound
# `model` (crashing import with NameError) and made the fallback unreachable.
for model_name, model_info in tf_models.items():
    model_path = f'/savedModels/{model_info["file"]}'
    try:
        logger.info(f"Attempting to load {model_name} from {model_path}")
        models[model_name] = _load_keras_model(model_path)
    except Exception as e:
        logger.error(f"Error loading {model_name}: {e}")
    else:
        logger.info(f"Loaded {model_name} successfully")
        continue

    # Primary file failed: fall back to alternative weights when available.
    fallback_file = _TF_FALLBACK_FILES.get(model_name)
    if fallback_file is None:
        continue
    try:
        logger.info(f"Attempting to load {fallback_file} as final alternative")
        models[model_name] = _load_keras_model(f'/savedModels/{fallback_file}')
    except Exception as final_error:
        logger.error(f"Failed to load any {model_name} model: {final_error}")
    else:
        logger.info(f"Loaded {model_name} (using {fallback_file}) successfully")
class PredictionRequest(BaseModel):
    """Request body for ``POST /predict``."""

    # Name of a loaded model; unknown names are rejected by /predict with 404.
    model_name: str
    # Flat numeric feature vector. /predict passes it to joblib models as a
    # single row (reshaped to (1, -1)). For TensorFlow models it is reshaped to
    # (-1, *input_shape) and divided by 255 when any value exceeds 1.0.
    features: List[Union[float, int]]
class ImagePredictionRequest(BaseModel):
    """Request body for ``POST /predict_image``."""

    # Name of a loaded model; unknown names are rejected with 404.
    model_name: str
    # Base64-encoded bytes of an image file that PIL's Image.open can read.
    image: str  # base64 encoded image
@app.post("/predict")
async def predict(request: PredictionRequest):
    """Run a loaded model on a flat feature vector.

    Joblib models receive the features as one row. TensorFlow models receive
    them reshaped to ``(-1, *input_shape)``; values are scaled by 1/255 when
    the maximum exceeds 1.0.

    Returns the raw ``predict`` output as nested lists under ``"prediction"``.
    Responds 404 for an unknown model, 400 when a ``ValueError`` is raised
    (e.g. the feature count does not fit the expected shape), and 500 for any
    other failure.
    """
    name = request.model_name
    if name not in models:
        available = list(models.keys())
        raise HTTPException(
            status_code=404,
            detail=f"Model {name} not found. Available models: {available}",
        )

    try:
        model = models[name]
        features = np.array(request.features)

        if name not in tf_models:
            prediction = model.predict(features.reshape(1, -1))
        else:
            # Prepend a batch axis to the configured per-sample shape.
            features = features.reshape((-1,) + tf_models[name]['input_shape'])
            # Inputs that look like 0-255 pixel values are scaled to 0-1.
            if features.max() > 1.0:
                features = features / 255.0
            prediction = model.predict(features)

        return {
            "model_name": name,
            "prediction": prediction.tolist(),
            "status": "success",
        }
    except ValueError as ve:
        raise HTTPException(status_code=400, detail=f"Invalid input shape: {str(ve)}")
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Prediction error: {str(e)}")
def _decode_image_payload(payload):
    """Return the raw bytes of a base64-encoded image string.

    Accepts bare base64 or a data URL (``data:image/png;base64,...``). The
    data-URL header is stripped first: non-validating b64decode would
    otherwise decode its alphabet characters ("data", "image", "/", ...)
    into the output and corrupt the image bytes.

    Raises:
        binascii.Error (a ValueError subclass): payload is not valid base64.
    """
    if payload.startswith("data:") and "," in payload:
        payload = payload.split(",", 1)[1]
    return base64.b64decode(payload)


@app.post("/predict_image")
async def predict_image(request: ImagePredictionRequest):
    """Classify a base64-encoded image with one of the TensorFlow models.

    The image is converted to grayscale (single-channel models) or RGB. It is
    then resized to the model's configured ``input_shape`` and scaled to
    [0, 1] before prediction.

    Returns:
        ``{"prediction": 0 or 1, "confidence": float}``. ``prediction`` is 1
        when the model's first output exceeds 0.5. ``confidence`` is that
        output times 100, rounded to 2 decimals.

    Raises:
        HTTPException(404): the model is unknown or failed to load.
        HTTPException(400): the model does not take image input, or the
            payload is not a decodable image.
        HTTPException(500): the model failed during prediction.
    """
    if request.model_name not in models:
        raise HTTPException(
            status_code=404,
            detail=f"Model {request.model_name} not found"
        )

    model_info = tf_models.get(request.model_name)
    if model_info is None:
        # joblib models take flat feature vectors via /predict, not images.
        raise HTTPException(
            status_code=400,
            detail=f"Model {request.model_name} does not accept image input",
        )
    # Assumes channels-last (height, width, channels), matching the previous
    # hard-coded reshape(1, 36, 36, 1) for the grayscale pneumonia model.
    height, width, channels = model_info['input_shape']

    try:
        image_bytes = _decode_image_payload(request.image)
        with Image.open(io.BytesIO(image_bytes)) as source:
            # PIL's resize() takes (width, height).
            resized = source.convert('L' if channels == 1 else 'RGB').resize((width, height))
        pixels = np.array(resized) / 255.0
    except (ValueError, OSError, Image.DecompressionBombError) as e:
        # binascii.Error is a ValueError and PIL.UnidentifiedImageError is an
        # OSError: malformed client input is a 400, not a server error.
        raise HTTPException(status_code=400, detail=f"Invalid image data: {e}") from e

    try:
        batch = pixels.reshape(1, height, width, channels)
        prediction = models[request.model_name].predict(batch)
        probability = float(prediction[0][0])
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error processing image: {str(e)}") from e

    return {
        "prediction": int(probability > 0.5),
        # NOTE(review): this is the model's raw first output as a percentage
        # (presumably a sigmoid probability of class 1). For prediction 0 it is
        # therefore not the confidence in class 0. Kept as-is for API compatibility.
        "confidence": round(probability * 100, 2),
    }
@app.get("/health")
async def health_check():
    """Report service status and the load outcome of every configured model.

    ``status`` is always "healthy". ``loaded_models`` maps each configured
    joblib and TensorFlow model name to "loaded" or "failed".
    """
    configured_names = [*joblib_models, *tf_models]
    load_status = {
        name: ("loaded" if name in models else "failed")
        for name in configured_names
    }
    return {
        "status": "healthy",
        "available_models": list(models),
        "model_count": len(models),
        "loaded_models": load_status,
    }
@app.get("/models")
async def list_models():
    """Describe the loaded models, grouped by backend.

    Joblib models appear only when loaded. TensorFlow models are always listed
    with a "loaded"/"failed" status; ``expected_input_shape`` is given for
    loaded ones and is None otherwise.
    """
    joblib_info = {name: "loaded" for name in joblib_models if name in models}

    tensorflow_info = {}
    for name, model_info in tf_models.items():
        is_loaded = name in models
        tensorflow_info[name] = {
            "status": "loaded" if is_loaded else "failed",
            "expected_input_shape": model_info["input_shape"] if is_loaded else None,
        }

    return {
        "total_models": len(models),
        "available_models": list(models),
        "models_info": {
            "joblib_models": joblib_info,
            "tensorflow_models": tensorflow_info,
        },
    }
if __name__ == "__main__":
    # Running the file directly serves the app with uvicorn on all interfaces, port 8080.
    from uvicorn import run as serve

    serve(app, port=8080, host="0.0.0.0")