質問に答えてくれるAI「Flan-t5」を動かします。コードをコピペするだけで誰でも使えるようにしておきました。(GPU次第です)
また、連続して質問可能なようにしただけでなく、翻訳AI「NLLB200」を合わせて使うことで日本語対応させました。
最終的にこのようなやり取りが可能になります。
質問をどうぞ(日本語もOK): 4+3=
=============== 回答 ===============
7
質問をどうぞ(日本語もOK): りんごとは何ですか?
翻訳された質問What is an apple?
翻訳前の回答<pad> a fruit</s>
=============== 回答 ===============
果物
質問をどうぞ(日本語もOK): ニューヨークについて教えて
翻訳された質問Tell me about New York.
翻訳前の回答<pad> New York is a city in the United States of America, located on Long Island in the New York metropolitan area, and is the most populous city in the state, with a population of over 4.7 million people in 2010.</s>
=============== 回答 ===============
ニューヨークは、ニューヨーク大都市圏のロングアイランドに位置するアメリカ合衆国の都市であり、州で最も人口が多い都市であり、2010年には400万人以上の人口を抱えています。
1. 導入編
モデルの種類
- Flan-T5 small (80M)
- Flan-T5 base (250M)
- Flan-T5 large (780M)
- Flan-T5 XL (3B)
- Flan-T5 XXL (11B)
今回は「Flan-T5 XL (3B)」を使用します。
システム負荷
XL(3B)で大体11GBのVRAMを使用します。
環境に合わせてモデルを選びましょう、
環境構築
以下のコマンドで transformers をインストールします。
pip install transformers
グラボのCUDAコアを使いたい方は以下のコマンドで PyTorch も追加でインストールします。
ただし、PyTorchについてはCUDAバージョンなどに左右されるため、エラーなど出たら公式を確認してください。
ここでは一例として当方の環境(上記画像の通り)のコマンドを載せておきます。
pip3 install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu117
2. 実際に動かしてみる
動作確認用コード
from transformers import T5Tokenizer, T5ForConditionalGeneration
import torch
tokenizer = T5Tokenizer.from_pretrained("google/flan-t5-xl")
model = T5ForConditionalGeneration.from_pretrained("google/flan-t5-xl")
input_text = "translate English to German: How old are you?"
input_ids = tokenizer(input_text, return_tensors="pt").input_ids
if torch.cuda.is_available():
model = model.to("cuda")
with torch.no_grad():
outputs = model.generate(
input_ids.to(model.device),
max_new_tokens=100,
bad_words_ids=[[1], [5]],
pad_token_id=tokenizer.pad_token_id,
bos_token_id=tokenizer.bos_token_id,
eos_token_id=tokenizer.eos_token_id,
)
print(tokenizer.decode(outputs[0]))
初回実行時は、XL(3B)の場合、モデルなどのダウンロード(約10GB)が行われるため気長に待ちましょう。
ちなみにsmall(80M)は300MBです。
可愛いですね。
上記コードを実行すると、 このように出力されるハズです。
<pad> Wie alt sind Sie?</s>
入力プロンプト「translate English to German: How old are you?」
(How old are you?をドイツ語に翻訳してください)に対して、
「Wie alt sind Sie?」としっかり答えていますね。
使いやすくする
from transformers import T5Tokenizer, T5ForConditionalGeneration
import torch
tokenizer = T5Tokenizer.from_pretrained("google/flan-t5-xl")
model = T5ForConditionalGeneration.from_pretrained("google/flan-t5-xl")
if torch.cuda.is_available():
model = model.to("cuda")
while True:
input_text = input("質問をどうぞ(日本語はNG): ")
input_ids = tokenizer(input_text, return_tensors="pt").input_ids
with torch.no_grad():
outputs = model.generate(
input_ids.to(model.device),
max_new_tokens=100,
bad_words_ids=[[1], [5]],
pad_token_id=tokenizer.pad_token_id,
bos_token_id=tokenizer.bos_token_id,
eos_token_id=tokenizer.eos_token_id,
)
result = tokenizer.decode(outputs[0])
print("=" * 15, "回答", "=" * 15)
print(result.replace("</s>", "").replace("<pad>", "") + "\n")
コマンドプロンプトで何度も質疑応答が繰り返せるようにしました。
日本語非対応のため、質問は英語でしましょう。
こんな感じで動作します。
質問をどうぞ(日本語はNG): Explain me about New York.
=============== 回答 ===============
New York is a city in the U s of America, located on Long Island in the U s of America, and is the most populous city in the state of New York, with a population of over 212,790, as of the 2010 census, and the most populous city in the state of New York, with a population of over 212,790, as of the 2010 census, and the most populous city in the state
質問をどうぞ(日本語はNG): What is the highest mountain in Japan?
=============== 回答 ===============
mount fuji
質問をどうぞ(日本語はNG): 1+5=
=============== 回答 ===============
6
かなりいい感じです。
でもせっかくなら日本語で質疑応答したいですね。。。
質問をどうぞ(日本語はNG): 日本で一番高い山は?
=============== 回答 ===============
<unk>?
現状だとこのように質問できません。
間に翻訳AI「NLLB200」(Facebook(現Meta))の翻訳を入れてみましょう!
応用編 日本語未対応のT5を翻訳AIを噛ませて実行
上記ページ通りにしてNLLB200が動作するようにしておきます。
その後以下のコードをコピペすれば無理やりFlan T5を日本語対応させて動かせます。
Flan-T5-XL(3B)とNLLB200(distilled-600M)をあわせてVRAM14GBくらいです。
環境に合わせてモデルを変えましょう。
- Flan-T5 small (80M)
- Flan-T5 base (250M)
- Flan-T5 large (780M)
- Flan-T5 XL (3B)
- Flan-T5 XXL (11B)
import string
from transformers import T5Tokenizer, T5ForConditionalGeneration
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
import torch
def is_halfwidth(s): # 半角のみならTrue
return all(
c in string.ascii_letters + string.digits + string.punctuation + " " for c in s
)
def t5(input_text):
input_ids = tokenizer_T5(input_text, return_tensors="pt").input_ids
with torch.no_grad():
outputs = model_T5.generate(
input_ids.to(model_T5.device),
max_new_tokens=100,
bad_words_ids=[[1], [5]],
pad_token_id=tokenizer_T5.pad_token_id,
bos_token_id=tokenizer_T5.bos_token_id,
eos_token_id=tokenizer_T5.eos_token_id,
)
result = tokenizer_T5.decode(outputs[0])
return result
def nllb200(input_lang, output_lang, article):
tokenizer = AutoTokenizer.from_pretrained(
"facebook/nllb-200-distilled-600M", src_lang=input_lang
)
inputs = tokenizer(article, return_tensors="pt")
with torch.no_grad():
output_ids = model_nllb.generate(
**inputs.to(model_nllb.device),
forced_bos_token_id=tokenizer.lang_code_to_id[output_lang],
max_length=200,
)
result = tokenizer.decode(output_ids.tolist()[0])
result = result.replace("</s>", "").replace(f"{output_lang} ", "")
if output_lang == "jpn_Jpan":
result = result.replace(".", "。").replace(",", "、").replace(" ", "")
return result
tokenizer_T5 = T5Tokenizer.from_pretrained("google/flan-t5-xl")
model_T5 = T5ForConditionalGeneration.from_pretrained("google/flan-t5-xl")
model_nllb = AutoModelForSeq2SeqLM.from_pretrained("facebook/nllb-200-distilled-600M")
if torch.cuda.is_available():
model_T5 = model_T5.to("cuda")
model_nllb = model_nllb.to("cuda")
while True:
input_text = input("質問をどうぞ(日本語もOK): ")
if is_halfwidth(input_text) == True:
answer = t5(input_text)
elif is_halfwidth(input_text) == False:
input_text = nllb200("jpn_Jpan", "eng_Latn", input_text)
print("翻訳された質問" + input_text)
answer = t5(input_text)
print("翻訳前の回答" + answer)
answer = nllb200(
"eng_Latn", "jpn_Jpan", answer.replace("</s>", "").replace("<pad>", "")
)
print("=" * 15, "回答", "=" * 15)
print(answer.replace("</s>", "").replace("<pad>", "") + "\n")
上記コードを実行するとこのようになります。
質問をどうぞ(日本語もOK): ニューヨークについて教えて
翻訳された質問Tell me about New York.
翻訳前の回答<pad> New York is a city in the United States of America, located on Long Island in the New York metropolitan area, and is the most populous city in the state, with a population of over 4.7 million people in 2010.</s>
=============== 回答 ===============
ニューヨークは、ニューヨーク大都市圏のロングアイランドに位置するアメリカ合衆国の都市であり、州で最も人口が多い都市であり、2010年には400万人以上の人口を抱えています。
質問をどうぞ(日本語もOK): りんごとは何ですか?
翻訳された質問What is an apple?
翻訳前の回答<pad> a fruit</s>
=============== 回答 ===============
果物
翻訳をかませるので精度は落ちますが、そこそこ使い勝手が良いです。
他にも色々とAI関連の記事を書いています。合わせてご覧ください。
画像生成AI「Stable Diffusion」を動かす方法
コード生成AI「santacoder」をWindowsローカル環境で動かす方法
文章生成AI「GPT-2/rinna」をWindowsローカルで動かすには