gRPC

2019-12-09

tensorflow serving에서 제공하는 gRPC 통신을 사용해보기전 간단히 gRPC에 대한 개념을 공부하고, 간단한 예제를 구현해보았습니다.

gRPC

RPC는 RPC(Remote Procedure Call)는 원격 컴퓨터나 프로세스에 존재하는 함수를 호출하는데 사용하는 프로토콜입니다. 원격에 있는 함수를 동일 프로세스에 존재하는 함수를 호출하는 것 처럼 직접 호출할 수 있습니다. gRPC는 Google 의 RPC 구현체입니다.

기본적으로 gRPC는 서비스 인터페이스와 메시지의 구조를 모두 설명하기 위해 ProtoBuf 를 IDL (Interface Definition Language)로 사용합니다. ProtoBuf는 구글에서 만들고 사용하는 데이터 직렬화 라이브러리 입니다. IDL로 정의된 메시지를 기반으로 protoc 라는 컴파일러를 사용해서 원하는 언어의 코드를 생성할 수 있습니다.

(gRPC와 JSON을 사용 하는 HTTP Api 간 비교표)

gRPC에 대한 더 자세한 소개와 장점에 대해서는 https://medium.com/@goinhacker/microservices-with-grpc-d504133d191d 에 잘 설명되어있습니다.

Tutorial

https://github.com/grpc/grpc/tree/v1.25.0/examples/python/route_guide 을 참고하여 간단한 예제를 만들어보았습니다.

먼저 필요한 라이브러리 설치를 위해 아래와 같이 grpcio-tools 를 설치해 줍니다.

pip install grpcio-tools

예제에서는 gRPC 통신을 통해 두 수의 합을 구해주는 서비스를 제공하는 서버와, 클라이언트를 구현합니다.

service.proto

먼저 protos 디렉토리 아래 service.proto라는 파일을 만들어서 두 수의 합을 구하는 서비스의 인터페이스와 message의 형태를 정의합니다.

syntax = "proto3";

package test;

service Test {
  rpc GetSum(Data) returns (Sum) {}
}

message Data {
  int32 a = 1;
  int32 b = 2;
}

message Sum {
  int32 sum = 1;
}

다음에서 구현할 서버와 클라이언트에서 위 service.proto 에 명세 대로 gRPC를 이용하여 통신하게 됩니다.

run_codegen.py

from grpc_tools import protoc

protoc.main((
    '',
    '-Iprotos',
    '--python_out=python_out',
    '--grpc_python_out=python_out',
    'protos/service.proto',
))

protos/service.proto 에 정의된 명세대로 python stub 코드를 python_out 이라는 디렉토리에 생성해주는 코드입니다.

run_codegen.py 을 통해 만들어지는 코드는 다음과 같습니다.

  • python_out/service_pb2.py
# -*- coding: utf-8 -*-
# Generated by the protocol buffer compiler.  DO NOT EDIT!
# source: service.proto

import sys
_b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1'))
from google.protobuf import descriptor as _descriptor
from google.protobuf import message as _message
from google.protobuf import reflection as _reflection
from google.protobuf import symbol_database as _symbol_database
# @@protoc_insertion_point(imports)

_sym_db = _symbol_database.Default()




DESCRIPTOR = _descriptor.FileDescriptor(
  name='service.proto',
  package='test',
  syntax='proto3',
  serialized_options=None,
  serialized_pb=_b('\n\rservice.proto\x12\x04test\"\x1c\n\x04\x44\x61ta\x12\t\n\x01\x61\x18\x01 \x01(\x05\x12\t\n\x01\x62\x18\x02 \x01(\x05\"\x12\n\x03Sum\x12\x0b\n\x03sum\x18\x01 \x01(\x05\x32)\n\x04Test\x12!\n\x06GetSum\x12\n.test.Data\x1a\t.test.Sum\"\x00\x62\x06proto3')
)




_DATA = _descriptor.Descriptor(
  name='Data',
  full_name='test.Data',
  filename=None,
  file=DESCRIPTOR,
  containing_type=None,
  fields=[
    _descriptor.FieldDescriptor(
      name='a', full_name='test.Data.a', index=0,
      number=1, type=5, cpp_type=1, label=1,
      has_default_value=False, default_value=0,
      message_type=None, enum_type=None, containing_type=None,
      is_extension=False, extension_scope=None,
      serialized_options=None, file=DESCRIPTOR),
    _descriptor.FieldDescriptor(
      name='b', full_name='test.Data.b', index=1,
      number=2, type=5, cpp_type=1, label=1,
      has_default_value=False, default_value=0,
      message_type=None, enum_type=None, containing_type=None,
      is_extension=False, extension_scope=None,
      serialized_options=None, file=DESCRIPTOR),
  ],
  extensions=[
  ],
  nested_types=[],
  enum_types=[
  ],
  serialized_options=None,
  is_extendable=False,
  syntax='proto3',
  extension_ranges=[],
  oneofs=[
  ],
  serialized_start=23,
  serialized_end=51,
)


_SUM = _descriptor.Descriptor(
  name='Sum',
  full_name='test.Sum',
  filename=None,
  file=DESCRIPTOR,
  containing_type=None,
  fields=[
    _descriptor.FieldDescriptor(
      name='sum', full_name='test.Sum.sum', index=0,
      number=1, type=5, cpp_type=1, label=1,
      has_default_value=False, default_value=0,
      message_type=None, enum_type=None, containing_type=None,
      is_extension=False, extension_scope=None,
      serialized_options=None, file=DESCRIPTOR),
  ],
  extensions=[
  ],
  nested_types=[],
  enum_types=[
  ],
  serialized_options=None,
  is_extendable=False,
  syntax='proto3',
  extension_ranges=[],
  oneofs=[
  ],
  serialized_start=53,
  serialized_end=71,
)

DESCRIPTOR.message_types_by_name['Data'] = _DATA
DESCRIPTOR.message_types_by_name['Sum'] = _SUM
_sym_db.RegisterFileDescriptor(DESCRIPTOR)

Data = _reflection.GeneratedProtocolMessageType('Data', (_message.Message,), {
  'DESCRIPTOR' : _DATA,
  '__module__' : 'service_pb2'
  # @@protoc_insertion_point(class_scope:test.Data)
  })
_sym_db.RegisterMessage(Data)

Sum = _reflection.GeneratedProtocolMessageType('Sum', (_message.Message,), {
  'DESCRIPTOR' : _SUM,
  '__module__' : 'service_pb2'
  # @@protoc_insertion_point(class_scope:test.Sum)
  })
_sym_db.RegisterMessage(Sum)



_TEST = _descriptor.ServiceDescriptor(
  name='Test',
  full_name='test.Test',
  file=DESCRIPTOR,
  index=0,
  serialized_options=None,
  serialized_start=73,
  serialized_end=114,
  methods=[
  _descriptor.MethodDescriptor(
    name='GetSum',
    full_name='test.Test.GetSum',
    index=0,
    containing_service=None,
    input_type=_DATA,
    output_type=_SUM,
    serialized_options=None,
  ),
])
_sym_db.RegisterServiceDescriptor(_TEST)

DESCRIPTOR.services_by_name['Test'] = _TEST

# @@protoc_insertion_point(module_scope)

  • python_out/service_pb2_grpc.py
# -*- coding: utf-8 -*-
# Generated by the protocol buffer compiler.  DO NOT EDIT!
# source: service.proto

import sys
_b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1'))
from google.protobuf import descriptor as _descriptor
from google.protobuf import message as _message
from google.protobuf import reflection as _reflection
from google.protobuf import symbol_database as _symbol_database
# @@protoc_insertion_point(imports)

_sym_db = _symbol_database.Default()




DESCRIPTOR = _descriptor.FileDescriptor(
  name='service.proto',
  package='test',
  syntax='proto3',
  serialized_options=None,
  serialized_pb=_b('\n\rservice.proto\x12\x04test\"\x1c\n\x04\x44\x61ta\x12\t\n\x01\x61\x18\x01 \x01(\x05\x12\t\n\x01\x62\x18\x02 \x01(\x05\"\x12\n\x03Sum\x12\x0b\n\x03sum\x18\x01 \x01(\x05\x32)\n\x04Test\x12!\n\x06GetSum\x12\n.test.Data\x1a\t.test.Sum\"\x00\x62\x06proto3')
)




_DATA = _descriptor.Descriptor(
  name='Data',
  full_name='test.Data',
  filename=None,
  file=DESCRIPTOR,
  containing_type=None,
  fields=[
    _descriptor.FieldDescriptor(
      name='a', full_name='test.Data.a', index=0,
      number=1, type=5, cpp_type=1, label=1,
      has_default_value=False, default_value=0,
      message_type=None, enum_type=None, containing_type=None,
      is_extension=False, extension_scope=None,
      serialized_options=None, file=DESCRIPTOR),
    _descriptor.FieldDescriptor(
      name='b', full_name='test.Data.b', index=1,
      number=2, type=5, cpp_type=1, label=1,
      has_default_value=False, default_value=0,
      message_type=None, enum_type=None, containing_type=None,
      is_extension=False, extension_scope=None,
      serialized_options=None, file=DESCRIPTOR),
  ],
  extensions=[
  ],
  nested_types=[],
  enum_types=[
  ],
  serialized_options=None,
  is_extendable=False,
  syntax='proto3',
  extension_ranges=[],
  oneofs=[
  ],
  serialized_start=23,
  serialized_end=51,
)


_SUM = _descriptor.Descriptor(
  name='Sum',
  full_name='test.Sum',
  filename=None,
  file=DESCRIPTOR,
  containing_type=None,
  fields=[
    _descriptor.FieldDescriptor(
      name='sum', full_name='test.Sum.sum', index=0,
      number=1, type=5, cpp_type=1, label=1,
      has_default_value=False, default_value=0,
      message_type=None, enum_type=None, containing_type=None,
      is_extension=False, extension_scope=None,
      serialized_options=None, file=DESCRIPTOR),
  ],
  extensions=[
  ],
  nested_types=[],
  enum_types=[
  ],
  serialized_options=None,
  is_extendable=False,
  syntax='proto3',
  extension_ranges=[],
  oneofs=[
  ],
  serialized_start=53,
  serialized_end=71,
)

DESCRIPTOR.message_types_by_name['Data'] = _DATA
DESCRIPTOR.message_types_by_name['Sum'] = _SUM
_sym_db.RegisterFileDescriptor(DESCRIPTOR)

Data = _reflection.GeneratedProtocolMessageType('Data', (_message.Message,), {
  'DESCRIPTOR' : _DATA,
  '__module__' : 'service_pb2'
  # @@protoc_insertion_point(class_scope:test.Data)
  })
_sym_db.RegisterMessage(Data)

Sum = _reflection.GeneratedProtocolMessageType('Sum', (_message.Message,), {
  'DESCRIPTOR' : _SUM,
  '__module__' : 'service_pb2'
  # @@protoc_insertion_point(class_scope:test.Sum)
  })
_sym_db.RegisterMessage(Sum)



_TEST = _descriptor.ServiceDescriptor(
  name='Test',
  full_name='test.Test',
  file=DESCRIPTOR,
  index=0,
  serialized_options=None,
  serialized_start=73,
  serialized_end=114,
  methods=[
  _descriptor.MethodDescriptor(
    name='GetSum',
    full_name='test.Test.GetSum',
    index=0,
    containing_service=None,
    input_type=_DATA,
    output_type=_SUM,
    serialized_options=None,
  ),
])
_sym_db.RegisterServiceDescriptor(_TEST)

DESCRIPTOR.services_by_name['Test'] = _TEST

# @@protoc_insertion_point(module_scope)

server 와 client 코드에서 위 모듈을 불러와서 사용하게 됩니다.

server.py

서버 코드를 아래와 같이 구현합니다.

from concurrent import futures

import grpc

from python_out.service_pb2 import Data, Sum
import python_out.service_pb2_grpc as service_pb2_grpc


class TestServicer(service_pb2_grpc.TestServicer):
    def __init__(self):
        pass

    def GetSum(self, request: Data, context):
        return Sum(sum=request.a + request.b)


def serve():
    server = grpc.server(futures.ThreadPoolExecutor(max_workers=10))
    service_pb2_grpc.add_TestServicer_to_server(
        TestServicer(), server
    )
    server.add_insecure_port('[::]:50051')
    server.start()
    server.wait_for_termination()


if __name__ == '__main__':
    serve()

service_pb2_grpc.py 에 정의된 stub class TestServicer 를 상속받아서 service.proto 에 정의했던 rpc GetSum(Data) returns (Sum) {} 의 실제 동작 부분을 구현해주면 됩니다. 그리고 serve() 에서 서버를 실행시켜줍니다. server.py 코드를 실행하면 서버가 실행됩니다.

client.py

import grpc


import python_out.service_pb2 as service_pb2
import python_out.service_pb2_grpc as service_pb2_grpc


def get_sum(stub: service_pb2_grpc.TestStub):
    response = stub.GetSum(service_pb2.Data(a=1, b=2))
    print(f"Received sum: {response}")


def run():
    with grpc.insecure_channel('localhost:50051') as channel:
        stub = service_pb2_grpc.TestStub(channel)
        print("-------------- Get Sum --------------")
        get_sum(stub)


if __name__ == '__main__':
    run()

다음은 server에 요청을 보내서 두 수의 합을 구하는 client 코드입니다. service_pb2_grpc 에 선언된 stub class를 이용해서 GetSum 함수를 호출하고 인자로 service_pb2 에 정의된 Data class 의 인스턴스를 넘깁니다. 그리고 response 값으로 두 수의 합을 받아옵니다. client 코드를 실행하면 다음과 같은 결과를 얻을 수 있습니다.

-------------- Get Sum --------------
Received sum: sum: 3

예제는 크게 어렵지 않았는데, gRPC의 개념에 대해서는 익숙하지 않아 이해하기 좀 어려웠습니다. 그렇지만 예제를 한번 작성해보고 나니, 개념과 장점에 대해서 좀 더 잘 와닿았던 것 같습니다. 다음에는 tensorflow serving 에서 제공하는 gRPC 통신으로 변경 시 성능 변화가 얼마나 있는지 테스트 해보려고 합니다.