receive_with_middleware에서 처리하고, Response는 send_with_middleware
에서 처리할 수 있도록 구조화했습니다.class ApplicationJsonMiddleware:
def __init__(
self,
app: ASGIApp
) -> None:
self.app = app
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
if scope["type"] == "http":
responder = _ApplicationJsonResponder(
self.app
)
await responder(scope, receive, send)
return
await self.app(scope, receive, send)
class _ApplicationJsonResponder:
    """Per-request wrapper around the ASGI ``receive`` and ``send`` channels.

    The wrapped application is called with ``receive_with_middleware`` and
    ``send_with_middleware`` instead of the server's own callables, so the
    request body and the response messages both pass through this object.
    """

    def __init__(self, app: ASGIApp) -> None:
        """Store the wrapped application and reset all per-request state."""
        self.app = app
        # Placeholders that raise until ``__call__`` attaches the real channels.
        self.receive: Receive = unattached_receive
        self.send: Send = unattached_send
        self.initial_message: Message = {}
        self.start_message: Message = {}
        # Request path, filled in by ``__call__``.
        self.path: str = ""
        # Cleared when the response content-type is not "application/json".
        self.should_send_json: bool = True
        # True once the whole request body has been handed to the application.
        self.request_consumed: bool = False
        # True once the deferred "http.response.start" has been forwarded.
        self.start_sent: bool = False

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        """Attach the server channels and run the wrapped application."""
        self.receive = receive
        self.send = send
        self.path = scope.get("path", "")
        await self.app(scope, self.receive_with_middleware, self.send_with_middleware)

    async def receive_with_middleware(self) -> Message:
        """Return the next inbound message with the request body aggregated.

        On the first call every ``http.request`` chunk is read until
        ``more_body`` is false, and a single message carrying the complete
        body is returned. Later calls, and any message that is not
        ``http.request``, are passed through unchanged.
        """
        if self.request_consumed:
            return await self.receive()

        message = await self.receive()
        if message["type"] != "http.request":
            return message

        chunks = [message.get("body", b"")]
        while message.get("more_body", False):
            message = await self.receive()
            if message["type"] != "http.request":
                # The stream ended early; let the application see why.
                return message
            chunks.append(message.get("body", b""))
        self.request_consumed = True

        # The body stays as bytes: nothing here needs text, and decoding
        # would fail on non-UTF-8 payloads.
        body = b"".join(chunks)
        if not body:
            return message
        return {**message, "body": body, "more_body": False}

    async def send_with_middleware(self, message: Message) -> None:
        """Forward an outbound message, deferring the start of JSON responses.

        Responses for ``/docs`` and ``/openapi.json`` and responses whose
        content-type is not exactly ``application/json`` are forwarded
        unchanged. For JSON responses the ``http.response.start`` message is
        kept in ``self.start_message`` and sent once, immediately before the
        first ``http.response.body`` message.
        """
        if self.path in ("/docs", "/openapi.json"):
            await self.send(message)
            return

        message_type = message["type"]
        if message_type == "http.response.start":
            headers = Headers(raw=message["headers"])
            # ``get`` avoids a KeyError when the response has no content-type.
            if headers.get("content-type") != "application/json":
                self.should_send_json = False
                await self.send(message)
                return
            self.start_message = message
            return

        if message_type == "http.response.body" and self.should_send_json:
            # A streamed response produces several body messages; the start
            # message may only be sent once.
            if not self.start_sent:
                self.start_sent = True
                await self.send(self.start_message)
            await self.send(message)
            return

        # Non-JSON bodies and any other message type pass through untouched.
        await self.send(message)
async def unattached_receive() -> Message:
    """Placeholder receive channel; always raises ``RuntimeError``."""
    raise RuntimeError("receive awaitable not set")  # pragma: no cover
async def unattached_send(message: Message) -> None:
    """Placeholder send channel; always raises ``RuntimeError``."""
    raise RuntimeError("send awaitable not set")  # pragma: no cover
# Excerpt: parse the request body and read the value stored under "k".
# ``get``, ``body``, ``settings`` and ``HTTPException`` come from outside this
# excerpt.
obj = json.loads(body)
k = get(obj, 'k')
# Decryption of the request payload is handled here.
if k:
    # Decrypt the payload to recover the JSON string and continue with the
    # decoded object in ``obj``.
    # NOTE(review): ECB mode is used; it reveals repeated plaintext blocks.
    aes_cipher = AES.new(settings.AES_KEY.encode("utf8"), AES.MODE_ECB)
    k = base64.b64decode(k)
    msg_dec = unpad(aes_cipher.decrypt(k), 16)
    obj = json.loads(msg_dec)
elif settings.ENVIRONMENT == 'prod':
    # In production a request without "k" is rejected.
    raise HTTPException(status_code=404)
# Excerpt: encrypt the response body when ``self.need_cipyer`` is set or the
# path is '/time', then send the deferred start message followed by the body.
# NOTE(review): indentation was lost in this excerpt, so it is not visible
# whether the two ``await self.send`` calls sit inside the ``if`` — confirm.
if self.need_cipyer or (self.path == '/time'):
# NOTE(review): ECB mode is used; it reveals repeated plaintext blocks.
cipyer = AES.new(settings.AES_KEY.encode("utf8"), AES.MODE_ECB)
# Pad to a 16-byte multiple, encrypt, then base64-encode.
msg_enc = pad(body, 16)
msg_enc = cipyer.encrypt(msg_enc)
k = base64.b64encode(msg_enc)
message['body'] = k
# Rewrite content-length on the stored start message to the encoded length.
headers = MutableHeaders(raw=self.start_message["headers"])
headers.update({'content-length':str(len(k))})
await self.send(self.start_message)
await self.send(message)
# Excerpt: outside production, fill a NetworkRecoder with this exchange, add it
# to the session obtained from get_db() and commit.
if settings.ENVIRONMENT != 'prod':
    recoder = NetworkRecoder()
    recoder.uid = self.uid
    recoder.url = self.path
    recoder.req_origin = self.strRequest
    recoder.status = self.status
    recoder.res_data = self.strResponse
    recoder.res_final = None
    # NOTE(review): this is self.time minus the current time; if self.time
    # holds an earlier timestamp the result is negative — confirm the sign.
    recoder.latency = self.time - int(time.time())
    recoder.created = datetime.now()
    # Bind the session before ``try`` so ``finally`` never touches an unbound
    # ``db`` and never masks an earlier failure with a NameError.
    db = next(get_db())
    try:
        db.connection()
        db.add(recoder)
        db.commit()
    finally:
        db.close()
# Excerpt: when ``self.enable_compress`` is set, gzip the body, then pad,
# encrypt and base64-encode it before storing it back on the message.
# NOTE(review): indentation was lost in this excerpt, so the extent of the
# ``if`` body is not visible — confirm which lines it guards.
if self.enable_compress:
msg_enc = gzip.compress(body)
# NOTE(review): ECB mode is used; it reveals repeated plaintext blocks.
cipyer = AES.new(settings.AES_KEY.encode("utf8"), AES.MODE_ECB)
# Pad the compressed bytes to a 16-byte multiple before encrypting.
msg_enc = pad(msg_enc, 16)
msg_enc = cipyer.encrypt(msg_enc)
k = base64.b64encode(msg_enc)
message['body'] = k
# Rewrite content-length on the stored start message to the encoded length.
headers = MutableHeaders(raw=self.start_message["headers"])
headers.update({'content-length':str(len(k))})
# Excerpt: copy the first address of the header named ``x`` into
# ``headers['origin_ip']``. ``x`` and ``headers`` come from outside this
# excerpt; presumably ``x`` is the raw (bytes) name of a forwarded-for header.
# Iterate over a snapshot because assigning to ``headers`` may modify
# ``headers.raw`` while it is being walked.
for header in list(headers.raw):
    if header[0] != x:
        continue
    # Keep only the first comma-separated entry. This covers a single address,
    # any number of forwarded hops, and commas with or without a trailing
    # space, where the old two-name unpack raised ValueError.
    origin_ip = header[1].decode('utf-8').split(',')[0].strip()
    headers['origin_ip'] = origin_ip