Skip to content

Commit

Permalink
add mps and fix thammegowda#20
Browse files Browse the repository at this point in the history
  • Loading branch information
hdcola committed Aug 20, 2024
1 parent 1c0227e commit 5bbb696
Showing 1 changed file with 51 additions and 33 deletions.
84 changes: 51 additions & 33 deletions nllb_serve/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,18 +19,27 @@

from . import log, DEF_MODEL_ID

# Pick the best available accelerator for inference, in preference order:
# CUDA (NVIDIA GPU) > MPS (Apple Silicon GPU) > CPU fallback.
if torch.cuda.is_available():
    device = torch.device('cuda')
elif torch.backends.mps.is_available():
    device = torch.device('mps')
else:
    device = torch.device('cpu')
log.info(f'torch device={device}')

# DEF_MODEL_ID = "facebook/nllb-200-distilled-600M"
DEF_SRC_LNG = 'eng_Latn'   # default source language code
DEF_TGT_LNG = 'kan_Knda'   # default target language code
FLOAT_POINTS = 4           # decimal places when rounding floats for JSON output
exp = None
app = Flask(__name__)
# Emit non-ASCII text (e.g. Kannada output) verbatim instead of \uXXXX escapes.
app.json.ensure_ascii = False

bp = Blueprint('nmt', __name__, template_folder='templates',
               static_folder='static')


sys_info = {
Expand All @@ -42,22 +51,25 @@
'GPU': '[unavailable]',
}
# Best-effort hardware/version probe for the /about page; must never
# prevent server startup, hence the broad (but logged) exception handler.
try:
    sys_info['torch'] = torch.__version__
    if torch.cuda.is_available():
        sys_info['GPU'] = str(torch.cuda.get_device_properties('cuda'))
        sys_info['Cuda Version'] = torch.version.cuda
    elif torch.backends.mps.is_available():
        sys_info['GPU'] = 'Apple MPS'
        sys_info['MPS Version'] = torch.__version__
    else:
        log.warning("CUDA/MPS unavailable")
except Exception:
    log.exception("Error while checking device availability")


def render_template(*template_args, **template_ctx):
    """Render a template via flask.render_template, additionally exposing the
    process environment to the template context as ``environ``.

    All positional and keyword arguments are forwarded unchanged.
    """
    return flask.render_template(*template_args, environ=os.environ, **template_ctx)


def jsonify(obj):

if obj is None or isinstance(obj, (int, bool, str)):
return obj
elif isinstance(obj, float):
Expand All @@ -66,7 +78,7 @@ def jsonify(obj):
return {key: jsonify(val) for key, val in obj.items()}
elif isinstance(obj, list):
return [jsonify(it) for it in obj]
#elif isinstance(ob, np.ndarray):
# elif isinstance(ob, np.ndarray):
# return _jsonify(ob.tolist())
else:
log.warning(f"Type {type(obj)} maybe not be json serializable")
Expand All @@ -77,9 +89,10 @@ def jsonify(obj):
def favicon():
    """Serve ``favicon.ico`` from this blueprint's ``static/favicon`` directory."""
    icon_dir = os.path.join(bp.root_path, 'static', 'favicon')
    return send_from_directory(icon_dir, 'favicon.ico')


def attach_translate_route(
model_id=DEF_MODEL_ID, def_src_lang=DEF_SRC_LNG,
def_tgt_lang=DEF_TGT_LNG, **kwargs):
model_id=DEF_MODEL_ID, def_src_lang=DEF_SRC_LNG,
def_tgt_lang=DEF_TGT_LNG, **kwargs):
sys_info['model_id'] = model_id
torch.set_grad_enabled(False)

Expand All @@ -93,7 +106,7 @@ def attach_translate_route(
@lru_cache(maxsize=256)
def get_tokenizer(src_lang=def_src_lang):
log.info(f"Loading tokenizer for {model_id}; src_lang={src_lang} ...")
#tokenizer = AutoTokenizer.from_pretrained(model_id)
# tokenizer = AutoTokenizer.from_pretrained(model_id)
return AutoTokenizer.from_pretrained(model_id, src_lang=src_lang)

@bp.route('/')
Expand All @@ -102,7 +115,6 @@ def index():
def_src_lang=def_src_lang, def_tgt_lang=def_tgt_lang)
return render_template('index.html', **args)


@bp.route("/translate", methods=["POST", "GET"])
def translate():
st = time.time()
Expand All @@ -116,7 +128,7 @@ def translate():
else:
args = request.form

if hasattr(args, 'getlist') :
if hasattr(args, 'getlist'):
sources = args.getlist("source")
else:
sources = args.get("source")
Expand All @@ -126,37 +138,38 @@ def translate():
src_lang = args.get('src_lang') or def_src_lang
tgt_lang = args.get('tgt_lang') or def_tgt_lang
sen_split = args.get('sen_split')

tokenizer = get_tokenizer(src_lang=src_lang)

if not sources:
return "Please submit 'source' parameter", 400

if sen_split:
if not ssplit_lang(src_lang):
return "Sentence splitter for this langauges is not availabe", 400
sources, index = sentence_splitter(src_lang, sources)
return "Sentence splitter for this langauges is not availabe", 400
sources, index = sentence_splitter(src_lang, sources)

max_length = 80
inputs = tokenizer(sources, return_tensors="pt", padding=True)
inputs = {k:v.to(device) for k, v in inputs.items()}
inputs = {k: v.to(device) for k, v in inputs.items()}

translated_tokens = model.generate(
**inputs, forced_bos_token_id=tokenizer.lang_code_to_id[tgt_lang],
max_length = max_length)
output = tokenizer.batch_decode(translated_tokens, skip_special_tokens=True)

**inputs, forced_bos_token_id=tokenizer.convert_tokens_to_ids(tgt_lang),
max_length=max_length)
output = tokenizer.batch_decode(
translated_tokens, skip_special_tokens=True)

if sen_split:
results = []
results = []
for i in range(1, len(index)):
batch = output[index[i-1]:index[i]]
results.append(" ".join(batch))
else:
results = output
results = output

res = dict(source=sources, translation=results,
src_lang = src_lang, tgt_lang=tgt_lang,
time_taken = round(time.time() - st, 3), time_units='s')
src_lang=src_lang, tgt_lang=tgt_lang,
time_taken=round(time.time() - st, 3), time_units='s')

return flask.jsonify(jsonify(res))

Expand All @@ -169,12 +182,17 @@ def parse_args():
parser = ArgumentParser(
prog="nllb-serve",
description="Deploy NLLB model to a RESTful server",
epilog=f'Loaded from {__file__}. Source code: https://github.com/thammegowda/nllb-serve',
epilog=f'Loaded from {
__file__}. Source code: https://github.com/thammegowda/nllb-serve',
formatter_class=ArgumentDefaultsHelpFormatter)
parser.add_argument("-d", "--debug", action="store_true", help="Run Flask server in debug mode")
parser.add_argument("-p", "--port", type=int, help="port to run server on", default=6060)
parser.add_argument("-ho", "--host", help="Host address to bind.", default='0.0.0.0')
parser.add_argument("-b", "--base", help="Base prefix path for all the URLs. E.g., /v1")
parser.add_argument("-d", "--debug", action="store_true",
help="Run Flask server in debug mode")
parser.add_argument("-p", "--port", type=int,
help="port to run server on", default=6060)
parser.add_argument(
"-ho", "--host", help="Host address to bind.", default='0.0.0.0')
parser.add_argument(
"-b", "--base", help="Base prefix path for all the URLs. E.g., /v1")
parser.add_argument("-mi", "--model_id", type=str, default=DEF_MODEL_ID,
help="model ID; see https://huggingface.co/models?other=nllb")
parser.add_argument("-msl", "--max-src-len", type=int, default=250,
Expand Down Expand Up @@ -207,4 +225,4 @@ def main():


if __name__ == "__main__":
main()
main()

0 comments on commit 5bbb696

Please sign in to comment.