stable-code-3b を試す | 株式会社バンコム

stable-code-3b を試す

January 23, 2024

Stable Diffusion でファンになってしまった stability ai の Coding 支援LLMというわけですから、試さざるを得ません!

https://ja.stability.ai/blog/stable-code-3b

実装方法は huggingface に書いてあります。

https://huggingface.co/stabilityai/stable-code-3b

環境構築

# 仮想環境構築
python3 -m venv venv
venv/bin/activate

pip install accelerate
pip install transformers

サンプルコード

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained("stabilityai/stable-code-3b", trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
  "stabilityai/stable-code-3b",
  trust_remote_code=True,
  torch_dtype="auto",
)
model.cuda()
inputs = tokenizer("import torch\nimport torch.nn as nn", return_tensors="pt").to(model.device)
tokens = model.generate(
  **inputs,
  max_new_tokens=48,
  temperature=0.2,
  do_sample=True,
)
print(tokenizer.decode(tokens[0], skip_special_tokens=True))

これは、サンプルそのままです。

実行結果

import torch
import torch.nn as nn
import torch.nn.functional as F

from. import utils
from. import layers
from. import losses
from. import models
from. import optimizers
from. import schedulers
from. import

おや?途中で終わってる?

書かせてみる

python で、1から10までカウントアップするだけのプログラムを書かせてみましょう。

まずは、入力を受け付けるように改造しました。

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# モデルとトークナイザーの初期化
tokenizer = AutoTokenizer.from_pretrained("stabilityai/stable-code-3b", trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    "stabilityai/stable-code-3b",
    trust_remote_code=True,
    torch_dtype="auto",
)
model.cuda()

# ユーザー入力の取得
user_input = input("ユーザー入力: ")

# トークン化
inputs = tokenizer(user_input, return_tensors="pt").to(model.device)

# モデルによる応答生成
tokens = model.generate(
    **inputs,
    max_new_tokens=200,
    temperature=0.2,
    do_sample=True,
)

# 応答のデコード
response = tokenizer.decode(tokens[0], skip_special_tokens=True)

# 応答の表示
print(response)

実行結果

$ python myllm2.py 
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
Loading checkpoint shards: 100%|███████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00,  3.03it/s]
ユーザー入力: python code of counting up from 1 to 100.
Setting `pad_token_id` to `eos_token_id`:0 for open-end generation.
python code of counting up from 1 to 100.

# In[2]:


for i in range(1,101):
    print(i)


# In[3]:


# Write a Python program to count the number of items in the list.

# In[4]:


list1 = [1,2,3,4,5,6,7,8,9,10]

print(len(list1))


# In[5]:


# Write a Python program to count the number of items of each type in a list.

# In[6]:


list1 = [1,2,3,4,5,6,7,8,9,10]

print(len(set(list1)))


# In[7]:


# Write a Python program to count the number of items of each type in a list.

# In[8]:


list1 =

もうっちょ長いものを書かせてみる

https://www.yahoo.co.jp/ のタイトルを Python Selenium で取得するコマンドを書かせてみましょう。
書けるかな?

ユーザー入力: Python code for Scraping a title tag of www.yahoo.co.jp by selenium.        
Setting `pad_token_id` to `eos_token_id`:0 for open-end generation.
Python code for Scraping a title tag of www.yahoo.co.jp by selenium.

# -*- coding: utf-8 -*-

from selenium import webdriver
from selenium.webdriver.common.keys import Keys
from selenium.webdriver.common.by import By
from selenium.webdriver.support.ui import WebDriverWait
from selenium.webdriver.support import expected_conditions as EC
from selenium.common.exceptions import TimeoutException

# ブラウザの設定
browser = webdriver.Firefox()

# ページ遷移の待ち時間
wait = WebDriverWait(browser, 10)

# ページ遷移の待ち時間
wait = WebDriverWait(browser, 10)

# ページ遷移の待ち時間
wait = WebDriverWait(browser, 10)

# ページ遷移の

うーん、、、残念!