Skip to content

Commit 3e4019f

Browse files
committed
first draft
1 parent 7cca91d commit 3e4019f

File tree

5 files changed

+66
-0
lines changed

5 files changed

+66
-0
lines changed

README.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,6 @@
11
# Translator
22
Translate from one language to another.
3+
4+
# License
5+
6+
This project is ditributed under [Mozilla Public License 2.0](LICENSE)

translator/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
from translator.translate import Translator

translator/__main__.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
import sys
2+
from translator.main import main
3+
4+
if __name__ == "__main__":
5+
try:
6+
main()
7+
sys.exit(0)
8+
except KeyboardInterrupt:
9+
sys.exit(1)

translator/main.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
import sys
2+
3+
from argparse import ArgumentParser
4+
from translator import Translator
5+
6+
def parse_arguments():
7+
argument_parse = ArgumentParser(description="Translate from one language to another.")
8+
argument_parse.add_argument('sentence', nargs="*", help="Something to translate.")
9+
argument_parse.add_argument('-s', '--source', default="en", help="Source language to translate.")
10+
argument_parse.add_argument('-t', '--target', default="fr", help="Target language to translate.")
11+
argument_parse.add_argument('-l', '--max_length', default=500, help="Max length of output.")
12+
argument_parse.add_argument('-m', '--model_id', default="facebook/nllb-200-distilled-600M", help="HuggingFace model ID to use.")
13+
argument_parse.add_argument('-p', '--pipeline', default="translation", help="Pipeline task to use.")
14+
15+
return argument_parse.parse_args()
16+
17+
def main():
18+
args = parse_arguments()
19+
20+
translator = Translator(args.source, args.target, args.max_length, args.model_id, args.pipeline)
21+
print(translator.translate(args.sentence))
22+
23+
if __name__ == "__main__":
24+
try:
25+
main()
26+
sys.exit(0)
27+
except KeyboardInterrupt:
28+
sys.exit(1)

translator/translate.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
2+
import torch
3+
4+
class Translator:
5+
6+
def __init__(self, source_language, target_language, max_length=500, model_id="facebook/nllb-200-distilled-600M", pipe_line="translation") -> None:
7+
self.model_id = model_id
8+
self.device = 0 if torch.cuda.is_available() else -1
9+
self.model = AutoModelForSeq2SeqLM.from_pretrained(model_id)
10+
self.tokenizer = AutoTokenizer.from_pretrained(model_id)
11+
self.translator = pipeline(
12+
"translation",
13+
model=self.model,
14+
tokenizer=self.tokenizer,
15+
src_lang=source_language,
16+
tgt_lang=target_language,
17+
max_length=max_length,
18+
device=self.device,
19+
)
20+
21+
def translate(self, sentence: str):
22+
return self.translator(sentence)[0]['translation_text']
23+
24+

0 commit comments

Comments
 (0)