Node.js で gRPC を試す

gRPC Server Reflection のクライアント処理」では Node.js で gRPC クライアントを実装しましたが、今回はサーバー側も実装してみます。

サンプルコードは http://github.com/fits/try_samples/tree/master/blog/20201115/

はじめに

gRPC on Node.js では、以下の 2通りの手法が用意されており、それぞれ使用するパッケージが異なります。

  • (a) 動的コード生成 (@grpc/proto-loader パッケージを使用)
  • (b) 静的コード生成 (grpc-tools パッケージを使用)

更に、gRPC の実装ライブラリとして以下の 2種類が用意されており、どちらかを使う事になります。

  • C-based Client and Server (grpc パッケージを使用)
  • Pure JavaScript Client (@grpc/grpc-js パッケージを使用)

@grpc/grpc-js は現時点で Pure JavaScript Client と表現されていますが、クライアントだけではなくサーバーの実装にも使えます。

ここでは、Pure JavaScript 実装の @grpc/grpc-js を使って、(a) と (b) の両方を試してみます。

サービス定義(proto ファイル)

gRPC のサービス定義として下記ファイルを使用します。

Unary RPC(1リクエスト / 1レスポンス)と Server streaming RPC(1リクエスト / 多レスポンス)、message の oneofgoogle.protobuf.Empty の扱い等を確認するような内容にしてみました。

proto/item.proto
syntax = "proto3";

import "google/protobuf/empty.proto";

package item;

message AddItemRequest {
    string item_id = 1;
    uint64 price = 2;
}

message ItemRequest {
    string item_id = 1;
}

message Item {
    string item_id = 1;
    uint64 price = 2;
}

message ItemSubscribeRequest {
}

message AddedItem {
    string item_id = 1;
    uint64 price = 2;
}

message RemovedItem {
    string item_id = 1;
}

message ItemEvent {
    oneof event {
        AddedItem added = 1;
        RemovedItem removed = 2;
    }
}

service ItemManage {
    rpc AddItem(AddItemRequest) returns (google.protobuf.Empty);
    rpc RemoveItem(ItemRequest) returns (google.protobuf.Empty);
    rpc GetItem(ItemRequest) returns (Item);

    rpc Subscribe(ItemSubscribeRequest) returns (stream ItemEvent);
}

(a) 動的コード生成(@grpc/proto-loader)

まずは、@grpc/proto-loader を使った動的コード生成を試します。

インストール

@grpc/proto-loader@grpc/grpc-js をインストールしておきます。

> npm install --save @grpc/proto-loader @grpc/grpc-js

サーバー実装

proto-loader の loadSync 関数 ※ で proto ファイルをロードした結果を grpc-js の loadPackageDefinition で処理する事で型定義などを動的に生成します。

 ※ 非同期版の load 関数も用意されています

addService で gRPC のサービス定義と処理をマッピングし、bindAsync 後に start を呼び出す事でサーバー処理を開始します。

proto ファイルで定義したメッセージ型と同じフィールドを持つ JavaScript オブジェクトを gRPC のリクエストやレスポンスで使う事ができるようです。

Unary RPC の場合は、第二引数の callback へ失敗時の値と成功時の値をそれぞれ渡す事で処理結果を返します。

任意のエラーを返したい場合は、code で gRPC のステータスコードを、details でエラー内容を指定します。

google.protobuf.Empty の箇所は null もしくは undefined で代用できるようです。

Server streaming RPC の場合は、第一引数(下記コードでは call)の write を呼び出す事で処理結果を返す事ができます。

クライアントが途中で切断したりすると cancelled が発生するようになっており、cancelled 発生後に write を呼び出してもエラー等は発生しないようになっていました。

server.js
const protoLoader = require('@grpc/proto-loader')
const grpc = require('@grpc/grpc-js')

const protoFile = './proto/item.proto'
// proto ファイルのロード
const pd = protoLoader.loadSync(protoFile)
// gPRC 用の動的な型定義生成
const proto = grpc.loadPackageDefinition(pd)

let store = []
let subscribeList = []

const findItem = itemId => store.find(i => i.itemId == itemId)

const addItem = (itemId, price) => {
    if (findItem(itemId)) {
        return undefined
    }

    const item = { itemId, price }

    store.push(item)

    return item
}

const removeItem = itemId => {
    const item = findItem(itemId)

    if (item) {
        store = store.filter(i => i.itemId != item.itemId)
    }

    return item
}
// ItemEvent の配信
const publishEvent = event => {
    console.log(`*** publish event: ${JSON.stringify(event)}`)
    subscribeList.forEach(s => s.write(event))
}

const server = new grpc.Server()
// サービス定義と処理のマッピング
server.addService(proto.item.ItemManage.service, {
    AddItem(call, callback) {
        const itemId = call.request.itemId
        const price = call.request.price

        const item = addItem(itemId, price)

        if (item) {
            callback()
            publishEvent({ added: { itemId, price }})
        }
        else {
            const err = { code: grpc.status.ALREADY_EXISTS, details: 'exists item' }
            callback(err)
        }
    },
    RemoveItem(call, callback) {
        const itemId = call.request.itemId

        if (removeItem(itemId)) {
            callback()
            publishEvent({ removed: { itemId }})
        }
        else {
            const err = { code: grpc.status.NOT_FOUND, details: 'item not found' }
            callback(err)
        }
    },
    GetItem(call, callback) {
        const itemId = call.request.itemId
        const item = findItem(itemId)

        if (item) {
            callback(null, item)
        }
        else {
            const err = { code: grpc.status.NOT_FOUND, details: 'item not found' }
            callback(err)
        }
    },
    Subscribe(call) {
        console.log('*** subscribed')
        subscribeList.push(call)
        // クライアント切断時の処理
        call.on('cancelled', () => {
            console.log('*** unsubscribed')
            subscribeList = subscribeList.filter(s => s != call)
        })
    }
})

server.bindAsync(
    '127.0.0.1:50051',
    grpc.ServerCredentials.createInsecure(),
    (err, port) => {
        if (err) {
            console.error(err)
            return
        }

        console.log(`start server: ${port}`)
        // 開始
        server.start()
    }
)

クライアント実装1

まずは、Unary RPC の API のみ(Subscribe 以外)を呼び出すクライアントを実装してみます。

loadPackageDefinition を実施するところまではサーバーと同じです。

Unary RPC はコールバック関数を伴ったメソッドとして用意されますが、このメソッドに Node.js の util.promisify を直接適用すると不都合が生じたため、Promise 化は自前の関数(下記の promisify)で実施するようにしました。

client.js
const protoLoader = require('@grpc/proto-loader')
const grpc = require('@grpc/grpc-js')

const protoFile = './proto/item.proto'

const pd = protoLoader.loadSync(protoFile)
const proto = grpc.loadPackageDefinition(pd)

const id = process.argv[2]

const client = new proto.item.ItemManage(
    '127.0.0.1:50051',
    grpc.credentials.createInsecure()
)
// Unary RPC の Promise 化
const promisify = (obj, methodName) => args => 
    new Promise((resolve, reject) => {
        obj[methodName](args, (err, res) => {
            if (err) {
                reject(err)
            }
            else {
                resolve(res)
            }
        })
    })

const addItem = promisify(client, 'AddItem')
const removeItem = promisify(client, 'RemoveItem')
const getItem = promisify(client, 'GetItem')

const printItem = item => {
    console.log(`id = ${item.itemId}, price = ${item.price}`)
}

const run = async () => {
    await addItem({ itemId: `${id}_item-1`, price: 100 })

    const item1 = await getItem({ itemId: `${id}_item-1` })
    printItem(item1)

    await addItem({ itemId: `${id}_item-2`, price: 20 })

    const item2 = await getItem({ itemId: `${id}_item-2` })
    printItem(item2)

    await addItem({ itemId: `${id}_item-1`, price: 50 })
        .catch(err => console.error(`*** ERROR = ${err.message}`))

    await removeItem({ itemId: `${id}_item-1` })

    await getItem({ itemId: `${id}_item-1` })
        .catch(err => console.error(`*** ERROR = ${err.message}`))

    await removeItem({ itemId: `${id}_item-2` })
}

run().catch(err => console.error(err))

クライアント実装2

次は Server streaming RPC のクライアント実装です。

client_subscribe.js
const protoLoader = require('@grpc/proto-loader')
const grpc = require('@grpc/grpc-js')

const protoFile = './proto/item.proto'

const pd = protoLoader.loadSync(protoFile)
const proto = grpc.loadPackageDefinition(pd)

const client = new proto.item.ItemManage(
    '127.0.0.1:50051',
    grpc.credentials.createInsecure()
)

const stream = client.Subscribe({})
// メッセージ受信時
stream.on('data', event => {
    console.log(`*** received event = ${JSON.stringify(event)}`)
})

// サーバー終了時
stream.on('end', () => console.log('*** stream end'))
stream.on('error', err => console.log(`*** stream error: ${err}`))

動作確認

server.js の実行後、client_subscribe.js を 2つ起動した後に client.js を実行してみます。

Server 実行
> node server.js
start server: 50051
Client2-1 実行
> node client_subscribe.js
Client2-2 実行
> node client_subscribe.js
Client1 実行
> node client.js a1
id = a1_item-1, price = 100
id = a1_item-2, price = 20
*** ERROR = 6 ALREADY_EXISTS: exists item
*** ERROR = 5 NOT_FOUND: item not found

この時点で出力内容は以下のようになりました。

Server 出力内容
> node server.js
start server: 50051
*** subscribed
*** subscribed
*** publish event: {"added":{"itemId":"a1_item-1","price":100}}
*** publish event: {"added":{"itemId":"a1_item-2","price":20}}
*** publish event: {"removed":{"itemId":"a1_item-1"}}
*** publish event: {"removed":{"itemId":"a1_item-2"}}
Client2-1、Client2-2 出力内容
> node client_subscribe.js
*** received event = {"added":{"itemId":"a1_item-1","price":{"low":100,"high":0,"unsigned":true}}}
*** received event = {"added":{"itemId":"a1_item-2","price":{"low":20,"high":0,"unsigned":true}}}
*** received event = {"removed":{"itemId":"a1_item-1"}}
*** received event = {"removed":{"itemId":"a1_item-2"}}

Client2-2 を Ctrl + c で終了後に、Server も Ctrl + c で終了すると以下のようになり、Client2-1 のプロセスは終了しました。

Server 出力内容
> node server.js
start server: 50051
・・・
*** publish event: {"removed":{"itemId":"a1_item-2"}}
*** unsubscribed
^C
Client2-1 出力内容
> node client_subscribe.js
・・・
*** received event = {"removed":{"itemId":"a1_item-2"}}
*** stream error: Error: 13 INTERNAL: Received RST_STREAM with code 2 (Internal server error)
*** stream end

特に問題はなく、正常に動作しているようです。

(b) 静的コード生成(grpc-tools)

grpc-tools を使った静的コード生成を試します。

インストールとコード生成

grpc-tools@grpc/grpc-jsgoogle-protobuf をインストールしておきます。

> npm install --save-dev grpc-tools
・・・
> npm install --save @grpc/grpc-js google-protobuf
・・・

grpc-tools をインストールする事で使えるようになる grpc_tools_node_protoc コマンドで proto ファイルからコードを生成します。

grpc_tools_node_protoc コマンドは内部的に protoc コマンドを grpc_node_plugin プラグインを伴って呼び出すようになっています。

--grpc_out でサービス定義用のファイル xxx_grpc_pb.js が生成され、--js_out でメッセージ定義用のファイルが生成されます。

サービス定義 xxx_grpc_pb.js は --js_out で import_style=commonjs オプションを指定する事を前提としたコードになっています。※

 ※ import_style=commonjs オプションを指定した際に生成される
    xxx_pb.js を参照するようになっている

また、--grpc_out はデフォルトで grpc パッケージ用のコードを生成するため、ここでは grpc_js オプションを指定して @grpc/grpc-js 用のコードを生成するようにしています。

静的コード生成例(grpc_tools_node_protoc コマンド)
> mkdir generated

> grpc_tools_node_protoc --grpc_out=grpc_js:generated --js_out=import_style=commonjs:generated proto/*.proto
・・・

サーバー実装

(a) の場合と処理内容に大きな違いはありませんが、リクエストやレスポンスでは生成された型を使います。

アクセサメソッド(getter、setter)で値の取得や設定ができるようになっており、new 時に配列として全フィールドの値を指定する事もできるようです。 JavaScript オブジェクトへ変換したい場合は toObject メソッドを使用します。

addService でマッピングする際のメソッド名の一文字目が小文字になっています。

proto ファイルで定義したサービス名の後に Service を付けた型(ここでは ItemManageService)がサーバー処理用、Client を付けた型がクライアント処理用の型定義となるようです。

server.js
const grpc = require('@grpc/grpc-js')

const { Item, AddedItem, RemovedItem, ItemEvent } = require('./generated/proto/item_pb')
const { ItemManageService } = require('./generated/proto/item_grpc_pb')
const { Empty } = require('google-protobuf/google/protobuf/empty_pb')

let store = []
let subscribeList = []

const findItem = itemId => store.find(i => i.getItemId() == itemId)

const addItem = (itemId, price) => {
    if (findItem(itemId)) {
        return undefined
    }

    const item = new Item([itemId, price])

    store.push(item)

    return item
}

const removeItem = itemId => {
    const item = findItem(itemId)

    if (item) {
        store = store.filter(i => i.getItemId() != item.getItemId())
    }

    return item
}

const createAddedEvent = (itemId, price) => {
    const event = new ItemEvent()
    event.setAdded(new AddedItem([itemId, price]))

    return event
}

const createRemovedEvent = itemId => {
    const event = new ItemEvent()
    event.setRemoved(new RemovedItem([itemId]))

    return event
}

const publishEvent = event => {
    // toObject で JavaScript オブジェクトへ変換
    console.log(`*** publish event: ${JSON.stringify(event.toObject())}`)
    subscribeList.forEach(s => s.write(event))
}

const server = new grpc.Server()

server.addService(ItemManageService, {
    addItem(call, callback) {
        const itemId = call.request.getItemId()
        const price = call.request.getPrice()

        const item = addItem(itemId, price)

        if (item) {
            callback(null, new Empty())
            publishEvent(createAddedEvent(itemId, price))
        }
        else {
            const err = { code: grpc.status.ALREADY_EXISTS, details: 'exists item' }
            callback(err)
        }
    },
    removeItem(call, callback) {
        const itemId = call.request.getItemId()

        if (removeItem(itemId)) {
            callback(null, new Empty())
            publishEvent(createRemovedEvent(itemId))
        }
        else {
            const err = { code: grpc.status.NOT_FOUND, details: 'item not found' }
            callback(err)
        }
    },
    getItem(call, callback) {
        const itemId = call.request.getItemId()
        const item = findItem(itemId)

        if (item) {
            callback(null, item)
        }
        else {
            const err = { code: grpc.status.NOT_FOUND, details: 'item not found' }
            callback(err)
        }
    },
    subscribe(call) {
        console.log('*** subscribed')
        subscribeList.push(call)

        call.on('cancelled', () => {
            console.log('*** unsubscribed')
            subscribeList = subscribeList.filter(s => s != call)
        })
    }
})

server.bindAsync(
    ・・・
)

クライアント実装1

生成された型を使う点とメソッド名の先頭が小文字になっている点を除くと、基本的に (a) と同じです。

client.js
const grpc = require('@grpc/grpc-js')

const { AddItemRequest, ItemRequest } = require('./generated/proto/item_pb')
const { ItemManageClient } = require('./generated/proto/item_grpc_pb')

const id = process.argv[2]

const client = new ItemManageClient(
    '127.0.0.1:50051',
    grpc.credentials.createInsecure()
)

const promisify = (obj, methodName) => args => 
    new Promise((resolve, reject) => {
        ・・・
    })

const addItem = promisify(client, 'addItem')
const removeItem = promisify(client, 'removeItem')
const getItem = promisify(client, 'getItem')

const printItem = item => {
    console.log(`id = ${item.getItemId()}, price = ${item.getPrice()}`)
}

const run = async () => {
    await addItem(new AddItemRequest([`${id}_item-1`, 100]))

    const item1 = await getItem(new ItemRequest([`${id}_item-1`]))
    printItem(item1)

    await addItem(new AddItemRequest([`${id}_item-2`, 20]))

    const item2 = await getItem(new ItemRequest([`${id}_item-2`]))
    printItem(item2)

    await addItem(new AddItemRequest([`${id}_item-1`, 50]))
        .catch(err => console.error(`*** ERROR = ${err.message}`))

    await removeItem(new ItemRequest([`${id}_item-1`]))

    await getItem(new ItemRequest([`${id}_item-1`]))
        .catch(err => console.error(`*** ERROR = ${err.message}`))

    await removeItem(new ItemRequest([`${id}_item-2`]))
}

run().catch(err => console.error(err))

クライアント実装2

こちらも同様です。

client_subscribe.js
const grpc = require('@grpc/grpc-js')

const { ItemSubscribeRequest } = require('./generated/proto/item_pb')
const { ItemManageClient } = require('./generated/proto/item_grpc_pb')

const client = new ItemManageClient(
    ・・・
)

const stream = client.subscribe(new ItemSubscribeRequest())

stream.on('data', event => {
    // toObject で JavaScript オブジェクトへ変換
    console.log(`*** received event = ${JSON.stringify(event.toObject())}`)
})

・・・

動作確認

(a) と同じ操作を行った結果は以下のようになりました。

Server 出力内容
> node server.js
start server: 50051
*** subscribed
*** subscribed
*** publish event: {"added":{"itemId":"a1_item-1","price":100}}
*** publish event: {"added":{"itemId":"a1_item-2","price":20}}
*** publish event: {"removed":{"itemId":"a1_item-1"}}
*** publish event: {"removed":{"itemId":"a1_item-2"}}
*** unsubscribed
^C
Client1 出力内容
> node client.js a1
id = a1_item-1, price = 100
id = a1_item-2, price = 20
*** ERROR = 6 ALREADY_EXISTS: exists item
*** ERROR = 5 NOT_FOUND: item not found
Client2-1 出力内容
> node client_subscribe.js
*** received event = {"added":{"itemId":"a1_item-1","price":100}}
*** received event = {"added":{"itemId":"a1_item-2","price":20}}
*** received event = {"removed":{"itemId":"a1_item-1"}}
*** received event = {"removed":{"itemId":"a1_item-2"}}
*** stream error: Error: 13 INTERNAL: Received RST_STREAM with code 2 (Internal server error)
*** stream end
Client2-2 出力内容
> node client_subscribe.js
*** received event = {"added":{"itemId":"a1_item-1","price":100}}
*** received event = {"added":{"itemId":"a1_item-2","price":20}}
*** received event = {"removed":{"itemId":"a1_item-1"}}
*** received event = {"removed":{"itemId":"a1_item-2"}}
^C

(a) と (b) は同一の gRPC サービス(proto ファイル)を実装したものなので当然ですが、(a) と (b) を相互接続しても特に問題はありませんでした。

RLlib を使ってナップサック問題を強化学習2

局所最適に陥っていたと思われる 前回 に対して、以下の改善案 ※ を思いついたので試してみました。

  • より困難な目標を達成した場合に報酬(価値)へボーナスを加算
 ※ 局所最適から脱して、より良い結果を目指す効果を期待

今回のサンプルコードは http://github.com/fits/try_samples/tree/master/blog/20201019/

サンプル1 改良版(ボーナス加算)

単一操作(品物の 1つを -1 or +1 するか何もしない)を行動とした(前回の)サンプル1 にボーナスを加算する処理を加えてみました。

とりあえず、価値の合計が 375(0-1 ナップサック問題としての最適解)を超えた場合に報酬へ +200 するようにしてみます。

前回から、変数や関数名を一部変更していますが、基本的な処理内容に変更はありません。

また、PPOTrainer では episode_reward_mean / vf_clip_param の値が 200 を超えると警告ログを出すようなので(ppo.pywarn_about_bad_reward_scales)、config で vf_clip_param(デフォルト値は 10)の値を変更するようにしています。

sample1_bonus.ipynb
・・・
def next_state(items, state, action):
    idx = action // 2
    act = action % 2

    if idx < len(items):
        state[idx] += (1 if act == 1 else -1)

    return state

def calc_value(items, state, max_weight, burst_value):
    reward = 0
    weight = 0
    
    for i in range(len(state)):
        reward += items[i][0] * state[i]
        weight += items[i][1] * state[i]
    
    if weight > max_weight or min(state) < 0:
        reward = burst_value
    
    return reward, weight

class Knapsack(gym.Env):
    def __init__(self, config):
        self.items = config["items"]
        self.max_weight = config["max_weight"]
        self.episode_steps = config["episode_steps"]
        self.burst_reward = config["burst_reward"]
        self.bonus_rules = config["bonus_rules"]
        
        n = self.episode_steps
        
        self.action_space = Discrete(len(self.items) * 2 + 1)
        self.observation_space = Box(low = -n, high = n, shape = (len(self.items), ))
        
        self.reset()

    def reset(self):
        self.current_steps = 0
        self.state = [0 for _ in self.items]
        
        return self.state

    def step(self, action):
        self.state = next_state(self.items, self.state, action)
        
        r, _ = calc_value(self.items, self.state, self.max_weight, self.burst_reward)
        reward = r
        
        # 段階的なボーナス加算
        for (v, b) in self.bonus_rules:
            if r > v:
                reward += b
        
        self.current_steps += 1
        done = self.current_steps >= self.episode_steps
        
        return self.state, reward, done, {}

items = [
    [105, 10],
    [74, 7],
    [164, 15],
    [32, 3],
    [235, 22]
]

config = {
    "env": Knapsack, 
    "vf_clip_param": 60,
    "env_config": {
        "items": items, "episode_steps": 10, "max_weight": 35, "burst_reward": -100, 
        # ボーナスの設定
        "bonus_rules": [ (375, 200) ]
    }
}

・・・

trainer = PPOTrainer(config = config)

・・・
# 30回の学習
for _ in range(30):
    r = trainer.train()
    ・・・

・・・

rs = []

# 1000回試行
for _ in range(1000):
    
    s = [0 for _ in range(len(items))]
    r_tmp = config["env_config"]["burst_reward"]

    for _ in range(config["env_config"]["episode_steps"]):
        a = trainer.compute_action(s)
        s = next_state(items, s, a)

        r, w = calc_value(items, s, config["env_config"]["max_weight"], config["env_config"]["burst_reward"])
        
        r_tmp = max(r, r_tmp)

    rs.append(r_tmp)

collections.Counter(rs)

上記の結果(30回の学習後に 1000回試行してそれぞれの最高値をカウント)は以下のようになりました。

結果
Counter({376: 957, 375: 42, 334: 1})

最適解の 376 が出るようになっており、ボーナスの効果はそれなりにありそうです。

ただし、毎回このような結果になるわけではなく、前回と同じように 375(0-1 ナップサック問題としての最適解)止まりとなる場合もあります。

検証

次に、ナップサック問題の内容を変えて検証してみます。

ここでは、「2.5 ナップサック問題 - 数理システム」の例題を題材として、状態の範囲やボーナスの内容を変えると結果にどのような差が生じるのかを確認します。

ナップサック問題の内容

価値 サイズ
120 10
130 12
80 7
100 9
250 21
185 16

最大容量(サイズ) 65 における最適解は以下の通りです。

価値 770 の組み合わせ(最適解)
3, 0, 2, 0, 1, 0
価値 745 の組み合わせ(0-1 ナップサック問題の最適解)
0, 1, 1, 1, 1, 1

1. 単一操作(品物の 1つを -1 or +1 するか何もしない)

行動は前回の サンプル1 と同様の以下とします。

  • 品物のどれか 1つを -1 or +1 するか、何も変更しない

ここでは、以下のような状態範囲とボーナスを試しました。

状態範囲(品物毎の個数の範囲)
状態タイプ 最小値 最大値
a -10 10
b 0 5
c 0 3
ボーナス定義
ボーナス定義タイプ v > 750 v > 760 v > 765
0 0 0 0
1 100 100 100
2 100 200 400

ボーナスは段階的に加算し、ボーナス定義タイプ 2 で価値が仮に 770 だった場合は、700(100 + 200 + 400)を加算する事にします。

また、状態(品物毎の個数)はその範囲を超えないよう最小値もしくは最大値で止まるようにしました。

なお、ここからは Jupyter Notebook ではなく Python スクリプトとして実行します。

test1.py
import sys
import numpy as np

import gym
from gym.spaces import Discrete, Box

import ray
from ray.rllib.agents.ppo import PPOTrainer

import collections

N = int(sys.argv[1])
EPISODE_STEPS = int(sys.argv[2])
STATE_TYPE = sys.argv[3]
BONUS_TYPE = sys.argv[4]

items = [
    [120, 10],
    [130, 12],
    [80, 7],
    [100, 9],
    [250, 21],
    [185, 16]
]

state_types = {
    "a": (-10, 10),
    "b": (0, 5),
    "c": (0, 3)
}

bonus_types = {
    "0": [],
    "1": [(750, 100), (760, 100), (765, 100)],
    "2": [(750, 100), (760, 200), (765, 400)]
}

vf_clip_params = {
    "0": 800,
    "1": 1100,
    "2": 1500
}

def next_state(items, state, action, state_range):
    idx = action // 2
    act = action % 2

    if idx < len(items):
        v = state[idx] + (1 if act == 1 else -1)
        # 状態が範囲内に収まるように調整
        state[idx] = min(state_range[1], max(state_range[0], v))

    return state

def calc_value(items, state, max_weight, burst_value):
    reward = 0
    weight = 0
    
    for i in range(len(state)):
        reward += items[i][0] * state[i]
        weight += items[i][1] * state[i]
    
    if weight > max_weight or min(state) < 0:
        reward = burst_value
    
    return reward, weight

class Knapsack(gym.Env):
    def __init__(self, config):
        self.items = config["items"]
        self.max_weight = config["max_weight"]
        self.episode_steps = config["episode_steps"]
        self.burst_reward = config["burst_reward"]
        self.state_range = config["state_range"]
        self.bonus_rules = config["bonus_rules"]
        
        self.action_space = Discrete(len(self.items) * 2 + 1)
        
        self.observation_space = Box(
            low = self.state_range[0], 
            high = self.state_range[1], 
            shape = (len(self.items), )
        )

        self.reset()

    def reset(self):
        self.current_steps = 0
        self.state = [0 for _ in self.items]
        
        return self.state

    def step(self, action):
        self.state = next_state(self.items, self.state, action, self.state_range)
        
        r, _ = calc_value(self.items, self.state, self.max_weight, self.burst_reward)
        reward = r

        for (v, b) in self.bonus_rules:
            if r > v:
                reward += b
        
        self.current_steps += 1
        done = self.current_steps >= self.episode_steps
        
        return self.state, reward, done, {}

config = {
    "env": Knapsack, 
    "vf_clip_param": vf_clip_params[BONUS_TYPE],
    "env_config": {
        "items": items, "max_weight": 65, "burst_reward": -100, 
        "episode_steps": EPISODE_STEPS, 
        "state_range": state_types[STATE_TYPE], 
        "bonus_rules": bonus_types[BONUS_TYPE]
    }
}

ray.init()

trainer = PPOTrainer(config = config)

for _ in range(N):
    r = trainer.train()
    print(f'iter = {r["training_iteration"]}')

print(f'N = {N}, EPISODE_STEPS = {EPISODE_STEPS}, state_type = {STATE_TYPE}, bonus_type = {BONUS_TYPE}')

rs = []

for _ in range(1000):
    s = [0 for _ in range(len(items))]
    r_tmp = config["env_config"]["burst_reward"]

    for _ in range(config["env_config"]["episode_steps"]):
        a = trainer.compute_action(s)
        s = next_state(items, s, a, config["env_config"]["state_range"])

        r, w = calc_value(
            items, s, 
            config["env_config"]["max_weight"], config["env_config"]["burst_reward"]
        )
        
        r_tmp = max(r, r_tmp)

    rs.append(r_tmp)

print( collections.Counter(rs) )

ray.shutdown()
実行例
> python test1.py 50 10 a 0

学習回数 50、1エピソードのステップ数 10 で学習した後、1000回の試行で最も件数の多かった価値を列挙する処理を 3回実施した結果です。(() 内の数値は 1000回の内にその値が最高値だった件数)

結果(学習回数 = 50、エピソードのステップ数 = 10)
状態タイプ ボーナス定義タイプ 状態の最小値 状態の最大値 765 超過時の総ボーナス 1回目 2回目 3回目
a-0 a 0 -10 10 0 735 (994) 735 (935) 735 (916)
a-1 a 1 -10 10 +300 745 (965) 745 (977) 735 (976)
a-2 a 2 -10 10 +700 735 (945) 735 (971) 770 (1000)
b-0 b 0 0 5 0 750 (931) 750 (829) 750 (1000)
b-1 b 1 0 5 +300 765 (995) 765 (998) 750 (609)
b-2 b 2 0 5 +700 765 (1000) 765 (995) 765 (998)
c-0 c 0 0 3 0 750 (998) 750 (996) 750 (1000)
c-1 c 1 0 3 +300 765 (1000) 750 (993) 765 (1000)
c-2 c 2 0 3 +700 765 (999) 765 (1000) 770 (999)

やはり、ボーナスは有効そうですが、状態タイプ a のように状態の範囲が広く、ボーナスの発生頻度が低くなるようなケースでは有効に働かない可能性も高そうです。

ボーナス定義タイプ 2 で最適解の 770 が出るようになっているものの、頻出するようなものでも無く、たまたま学習が上手くいった場合にのみ発生しているような印象でした。

また、b-1 の 3回目で件数が他と比べて低くなっていますが、こちらは学習が(順調に進まずに)足りていない状態だと考えられます。

2. 一括操作(全品物をそれぞれ -1 or 0 or +1 する)

次に、行動を以下のように変えて同じように検証してみます。

  • 全ての品物を対象にそれぞれを -1 or 0 or +1 する

こちらは、ボーナス加算タイプを 1種類追加しました。

状態範囲(品物毎の個数の範囲)
状態タイプ 最小値 最大値
a -10 10
b 0 5
c 0 3
ボーナス加算
ボーナス加算タイプ v > 750 v > 760 v > 765
0 0 0 0
1 100 100 100
2 100 200 400
3 200 400 800

行動の変更に伴い action_spaceBox で定義しています。

test2.py
・・・

state_types = {
    "a": (-10, 10),
    "b": (0, 5),
    "c": (0, 3)
}

bonus_types = {
    "0": [],
    "1": [(750, 100), (760, 100), (765, 100)],
    "2": [(750, 100), (760, 200), (765, 400)],
    "3": [(750, 200), (760, 400), (765, 800)]
}

・・・

def next_state(items, state, action, state_range):
    for i in range(len(action)):
        v = state[i] + round(action[i])
        state[i] = min(state_range[1], max(state_range[0], v))

    return state

・・・

class Knapsack(gym.Env):
    def __init__(self, config):
        self.items = config["items"]
        self.max_weight = config["max_weight"]
        self.episode_steps = config["episode_steps"]
        self.burst_reward = config["burst_reward"]
        self.state_range = config["state_range"]
        self.bonus_rules = config["bonus_rules"]
        # 品物毎の -1 ~ 1
        self.action_space = Box(low = -1, high = 1, shape = (len(self.items), ))
        
        self.observation_space = Box(
            low = self.state_range[0], 
            high = self.state_range[1], 
            shape = (len(self.items), )
        )

        self.reset()

    def reset(self):
        self.current_steps = 0
        self.state = [0 for _ in self.items]
        
        return self.state

    def step(self, action):
        self.state = next_state(self.items, self.state, action, self.state_range)
        
        r, _ = calc_value(self.items, self.state, self.max_weight, self.burst_reward)
        reward = r

        for (v, b) in self.bonus_rules:
            if r > v:
                reward += b
        
        self.current_steps += 1
        done = self.current_steps >= self.episode_steps
        
        return self.state, reward, done, {}

・・・

こちらの方法では、学習回数 50 では明らかに足りなかったので 100 にして実施しました。

結果(学習回数 = 100、エピソードのステップ数 = 10)
状態タイプ ボーナス加算タイプ 状態の最小値 状態の最大値 765 超過時の総ボーナス 1回目 2回目 3回目
a-0 a 0 -10 10 0 735 (477) 735 (531) 735 (714)
a-1 a 1 -10 10 +300 735 (689) 735 (951) 745 (666)
a-2 a 2 -10 10 +700 735 (544) 735 (666) 735 (719)
a-3 a 3 -10 10 +1400 745 (633) 735 (735) 735 (875)
b-0 b 0 0 5 0 735 (364) 760 (716) 740 (590)
b-1 b 1 0 5 +300 735 (935) 760 (988) 655 (685)
b-2 b 2 0 5 +700 760 (1000) 735 (310) 770 (963)
b-3 b 3 0 5 +1400 675 (254) 770 (1000) 770 (909)
c-0 c 0 0 3 0 735 (762) 740 (975) 740 (669)
c-1 c 1 0 3 +300 740 (935) 740 (842) 735 (963)
c-2 c 2 0 3 +700 770 (999) 770 (1000) 715 (508)
c-3 c 3 0 3 +1400 770 (1000) 770 (1000) 770 (1000)

学習の足りていない所が散見されますが(学習も不安定)、特定のタイプで最適解の 770 が割と頻繁に出るようになりました。

ただ、c-3 の場合でも 770 が出やすくなっているものの、確実にそのように学習するわけではありませんでした。

結局のところ、状態・行動・報酬の設計次第という事かもしれません。

RLlib を使ってナップサック問題を強化学習

ナップサック問題強化学習を適用すると、どうなるのか気になったので試してみました。

強化学習には、Ray に含まれている RLlib を使い、Jupyter Notebook 上で実行します。

今回のサンプルコードは http://github.com/fits/try_samples/tree/master/blog/20200922/

はじめに

以下のようにして Ray と RLlib をインストールしておきます。(TensorFlow も事前にインストールしておく)

Ray インストール
> pip install ray[rllib]

ナップサック問題

今回は、以下のような価値と重さを持った品物に対して、重さの合計が 35 以下で価値の合計を最大化する品物の組み合わせを探索する事にします。

価値 重さ
105 10
74 7
164 15
32 3
235 22

品物をそれぞれ 1個までしか選べない場合(0-1 ナップサック問題)の最適な組み合わせは、以下のように 5番目以外を 1個ずつ選ぶ事です。(価値の合計は 375

価値 375 の組み合わせ(0-1 ナップサック問題の最適解)
1, 1, 1, 1, 0

また、同じ品物をいくらでも選べる場合の最適な組み合わせは以下のようになります。(価値の合計は 376

価値 376 の組み合わせ
0, 2, 1, 2, 0

強化学習でこのような組み合わせを導き出す事ができるのか確認します。

1. サンプル1 - sample1.ipynb

とりあえず、強化学習における状態・行動・報酬を以下のようにしてみました。 エピソードは、指定した回数(今回は 10回)の行動を行う事で終了とします。

状態 行動 (即時)報酬
品物毎の個数 品物の個数を操作(-1, +1) 価値の合計

行動は以下のような 0 ~ 10 の数値で表現する事にします。

  • 0 = 1番目の品物の個数を -1
  • 1 = 1番目の品物の個数を +1
  • ・・・
  • 8 = 5番目の品物の個数を -1
  • 9 = 5番目の品物の個数を +1
  • 10 = 個数を変更しない(現状維持)

これらを OpenAI Gym で定義したのが次のコードです。

環境は gym.Env を継承し、__init__ で行動空間(action_space)と状態空間(observation_space)を定義、reset で状態の初期化、step で状態の更新と報酬の算出を行うように実装します。

step の戻り値は、更新後の状態報酬エピソード終了か否か(デバッグ用途等の)情報 となっています。

Discrete(n) は 0 ~ n - 1 の整数値、Box は low ~ high の実数値の多次元配列となっており、、行動空間と状態空間の定義にそれぞれ使用しています。

環境定義
import gym
from gym.spaces import Discrete, Box

def next_state(items, state, action):
    idx = action // 2
    act = action % 2

    if idx < len(items):
        state[idx] += (1 if act == 1 else -1)

    return state

# 報酬の算出
def calc_reward(items, state, max_weight, burst_reward):
    reward = 0
    weight = 0
    
    for i in range(len(state)):
        reward += items[i][0] * state[i]
        weight += items[i][1] * state[i]
    
    if weight > max_weight or min(state) < 0:
        reward = burst_reward
    
    return reward, weight

class Knapsack(gym.Env):
    def __init__(self, config):
        self.items = config["items"]
        # 重さの上限値
        self.max_weight = config["max_weight"]
        # 行動の回数
        self.max_count = config["max_count"]
        # 重さが超過するか、個数が負の数となった場合の報酬
        self.burst_reward = config["burst_reward"]
        
        n = self.max_count
        
        # 行動空間の定義
        self.action_space = Discrete(len(self.items) * 2 + 1)
        # 状態空間の定義
        self.observation_space = Box(low = -n, high = n, shape = (len(self.items), ))
        
        self.reset()

    def reset(self):
        self.count = 0
        self.state = [0 for _ in self.items]
        
        return self.state

    def step(self, action):
        # 状態の更新
        self.state = next_state(self.items, self.state, action)
        # 報酬の算出
        reward, _ = calc_reward(self.items, self.state, self.max_weight, self.burst_reward)
        
        self.count += 1
        # エピソード完了の判定
        done = self.count >= self.max_count
        
        return self.state, reward, done, {}

次に上記環境のための設定を行います。 env_config の内容が __init__ の config 引数となります。

重さの上限値を超えた場合や個数が負の数となった場合の報酬(burst_reward)をとりあえず -100 としています。

なお、基本的に RLlib のデフォルト設定値を使う事にします。

設定
items = [
    [105, 10],
    [74, 7],
    [164, 15],
    [32, 3],
    [235, 22]
]

config = {
    "env": Knapsack, 
    "env_config": {"items": items, "max_count": 10, "max_weight": 35, "burst_reward": -100}
}

強化学習を実施する前に、Ray を初期化しておきます。

Ray 初期化
import ray

ray.init()

(a) PPO(Proximal Policy Optimization)

PPO アルゴリズムを試してみます。

トレーナーの定義 - PPO
from ray.rllib.agents.ppo import PPOTrainer

trainer = PPOTrainer(config = config)

まずは、学習(train の実行)を 10回繰り返してみます。

後で学習時の状況を確認するために episode_reward_max 等の値を保持するようにしています。

学習
r_max = []
r_min = []
r_mean = []
from ray.tune.logger import pretty_print

for _ in range(10):
    r = trainer.train()
    print(pretty_print(r))

    r_max.append(r["episode_reward_max"])   # 最大
    r_min.append(r["episode_reward_min"])   # 最小
    r_mean.append(r["episode_reward_mean"]) # 平均

以下のコードで結果を確認してみます。 compute_action を呼び出す事で、指定した状態に対する行動を取得できます。

評価1
s = [0 for _ in range(len(items))]

for _ in range(config["env_config"]["max_count"]):
    a = trainer.compute_action(s)
    
    s = next_state(items, s, a)
    
    r, w = calc_reward(items, s, config["env_config"]["max_weight"], config["env_config"]["burst_reward"])
    
    print(f"{a}, {s}, {r}, {w}")

下記のように、価値の合計が 375 となる組み合わせ(0-1 ナップサック問題とした場合の最適解)が現れており、ある程度は学習できているように見えます。

評価1の結果 - PPO 学習 10回
1, [1, 0, 0, 0, 0], 105, 10
3, [1, 1, 0, 0, 0], 179, 17
5, [1, 1, 1, 0, 0], 343, 32
7, [1, 1, 1, 1, 0], 375, 35
7, [1, 1, 1, 2, 0], -100, 38
6, [1, 1, 1, 1, 0], 375, 35
0, [0, 1, 1, 1, 0], 270, 25
1, [1, 1, 1, 1, 0], 375, 35
6, [1, 1, 1, 0, 0], 343, 32
7, [1, 1, 1, 1, 0], 375, 35

これだけだとよく分からないので、100回繰り返してそれぞれの報酬の最高値をカウントしてみます。

評価2
import collections

rs = []

for _ in range(100):
    
    s = [0 for _ in range(len(items))]
    r_tmp = config["env_config"]["burst_reward"]

    for _ in range(config["env_config"]["max_count"]):
        a = trainer.compute_action(s)
        s = next_state(items, s, a)

        r, w = calc_reward(items, s, config["env_config"]["max_weight"], config["env_config"]["burst_reward"])
        
        r_tmp = max(r, r_tmp)

    rs.append(r_tmp)

collections.Counter(rs)

結果は以下のようになりました。 最高値の 376 ではなく 375 の組み合わせへ向かうように学習が進んでいるように見えます。

評価2の結果 - PPO 学習 10回
Counter({302: 6,
         309: 2,
         284: 1,
         343: 12,
         375: 39,
         270: 4,
         341: 7,
         373: 1,
         340: 2,
         269: 1,
         372: 6,
         301: 4,
         376: 1,
         334: 2,
         333: 2,
         344: 1,
         299: 4,
         317: 1,
         275: 1,
         349: 1,
         316: 1,
         312: 1})

更に、学習を 10回繰り返した後の結果です。 375 へ向かって収束しているように見えます。

評価1の結果 - PPO 学習 20回
1, [1, 0, 0, 0, 0], 105, 10
5, [1, 0, 1, 0, 0], 269, 25
3, [1, 1, 1, 0, 0], 343, 32
7, [1, 1, 1, 1, 0], 375, 35
10, [1, 1, 1, 1, 0], 375, 35
10, [1, 1, 1, 1, 0], 375, 35
10, [1, 1, 1, 1, 0], 375, 35
2, [1, 0, 1, 1, 0], 301, 28
3, [1, 1, 1, 1, 0], 375, 35
0, [0, 1, 1, 1, 0], 270, 25
評価2の結果 - PPO 学習 20回
Counter({375: 69, 373: 18, 376: 3, 341: 5, 372: 1, 302: 1, 344: 3})

50回学習した後の結果は以下のようになりました。

評価2の結果 - PPO 学習 50回
Counter({375: 98, 373: 2})

ここまでの(学習時の)報酬状況をグラフ化してみます。

学習回数と報酬のグラフ化
%matplotlib inline

import matplotlib.pyplot as plt

plt.plot(r_max, label = "reward_max", color = "red")
plt.plot(r_min, label = "reward_min", color = "green")
plt.plot(r_mean, label = "reward_mean", color = "blue")

plt.legend(loc = "upper left")

plt.ylim([-1000, 3700])
plt.ylabel("reward")

plt.show()
学習時の報酬グラフ - PPO 学習 50回

f:id:fits:20200922223349p:plain

エピソード内の報酬を高めていくには重量超過やマイナス個数を避けるのが重要、それには各品物の個数を 0 か 1 にしておくのが無難なため、0-1 ナップサック問題としての最適解へ向かっていくのかもしれません。

(b) DQN(Deep Q-Network)

比較のために DQN でも実行してみます。 PPOTrainer の代わりに DQNTrainer を使うだけです。

トレーナーの定義 - DQN
from ray.rllib.agents.dqn import DQNTrainer

trainer = DQNTrainer(config = config)

10回の学習では以下のような結果となりました。

評価1の結果 - DQN 学習 10回
5, [0, 0, 1, 0, 0], 164, 15
7, [0, 0, 1, 1, 0], 196, 18
2, [0, -1, 1, 1, 0], -100, 11
5, [0, -1, 2, 1, 0], -100, 26
2, [0, -2, 2, 1, 0], -100, 19
0, [-1, -2, 2, 1, 0], -100, 9
1, [0, -2, 2, 1, 0], -100, 19
3, [0, -1, 2, 1, 0], -100, 26
5, [0, -1, 3, 1, 0], -100, 41
6, [0, -1, 3, 0, 0], -100, 38
評価2の結果 - DQN 学習 10回
Counter({-100: 33,
         360: 1,
         105: 8,
         270: 2,
         372: 5,
         315: 1,
         0: 5,
         374: 1,
         74: 2,
         274: 1,
         235: 7,
         340: 3,
         32: 4,
         164: 5,
         238: 1,
         309: 2,
         106: 1,
         228: 2,
         169: 1,
         267: 3,
         343: 1,
         196: 1,
         331: 1,
         363: 1,
         254: 1,
         299: 1,
         317: 1,
         212: 1,
         301: 1,
         328: 1,
         138: 1,
         64: 1})

DQN は PPO に比べて episodes_total(学習で実施したエピソード数の合計)が 1/4 程度と少なかったので、40回まで実施してみました。

評価1の結果 - DQN 学習 40回
6, [0, 0, 0, -1, 0], -100, -3
7, [0, 0, 0, 0, 0], 0, 0
1, [1, 0, 0, 0, 0], 105, 10
4, [1, 0, -1, 0, 0], -100, -5
4, [1, 0, -2, 0, 0], -100, -20
8, [1, 0, -2, 0, -1], -100, -42
7, [1, 0, -2, 1, -1], -100, -39
1, [2, 0, -2, 1, -1], -100, -29
2, [2, -1, -2, 1, -1], -100, -36
4, [2, -1, -3, 1, -1], -100, -51
評価2の結果 - DQN 学習 40回
Counter({0: 5,
         328: 4,
         164: 8,
         -100: 28,
         340: 5,
         228: 1,
         309: 4,
         74: 1,
         238: 4,
         235: 7,
         32: 3,
         358: 2,
         343: 2,
         196: 2,
         269: 3,
         372: 2,
         365: 1,
         179: 2,
         374: 1,
         302: 1,
         106: 3,
         312: 1,
         105: 5,
         270: 2,
         148: 2,
         360: 1})

80回学習した結果です。

評価1 - DQN 学習 80回
5, [0, 0, 1, 0, 0], 164, 15
1, [1, 0, 1, 0, 0], 269, 25
7, [1, 0, 1, 1, 0], 301, 28
4, [1, 0, 0, 1, 0], 137, 13
0, [0, 0, 0, 1, 0], 32, 3
2, [0, -1, 0, 1, 0], -100, -4
7, [0, -1, 0, 2, 0], -100, -1
1, [1, -1, 0, 2, 0], -100, 9
0, [0, -1, 0, 2, 0], -100, -1
4, [0, -1, -1, 2, 0], -100, -16
評価2 - DQN 学習 80回
Counter({-100: 5,
         269: 8,
         164: 20,
         301: 3,
         211: 2,
         235: 7,
         105: 3,
         372: 2,
         333: 1,
         238: 2,
         316: 2,
         196: 2,
         253: 1,
         242: 2,
         312: 3,
         343: 3,
         340: 9,
         179: 2,
         267: 3,
         270: 1,
         74: 3,
         328: 9,
         309: 2,
         284: 1,
         137: 1,
         285: 1,
         360: 1,
         374: 1})

200回学習した結果です。

評価2 - DQN 学習 200回
Counter({-100: 8,
         238: 4,
         340: 10,
         0: 4,
         267: 6,
         106: 1,
         74: 7,
         105: 1,
         328: 4,
         235: 33,
         309: 9,
         228: 1,
         270: 2,
         32: 1,
         301: 1,
         196: 3,
         341: 2,
         269: 1,
         374: 1,
         372: 1})

学習時の報酬グラフは以下の通り、PPO のようにスムーズに学習が進んでおらず、DQN は本件に向いていないのかもしれません。

学習時の報酬グラフ - DQN 学習 200回

f:id:fits:20200922223834p:plain

これは、報酬のクリッピング ※ に因るものかとも思いましたが、RLlib における報酬クリッピングの設定 clip_rewards はデフォルトで None であり、DQN のデフォルト設定 (dqn.py の DEFAULT_CONFIG) においても有効化しているような箇所は見当たりませんでした。

他の箇所で実施しているのかもしれませんが、今回は確認できませんでした。

※ 基本的には、
   元の報酬の符号に合わせて 1, -1, 0 のいずれかに報酬の値が変更されてしまい、
   報酬の大小は考慮されなくなる

   本件であれば、
   重量超過やマイナス個数のように報酬がマイナスにならない行動であれば
   どれでも良い事になってしまうと思われる

ちなみに、clip_rewardsTrue に設定した PPOTrainer を使ってみたところ、結果が著しく悪化しました。(マイナス報酬を避けるだけになった)

2. サンプル2 - sample2.ipynb

次に、行動によってのみ次の状態を決定するようにしてみます。 エピソードは 1回の行動で終了とします。

状態 行動 (即時)報酬
品物毎の個数 品物毎の個数 価値の合計

この場合、action_spaceBox で定義する事になります。 最軽量の品物が重要制限内で選択できる最大の個数を Box の最大値としています。

環境定義
def next_state(action):
    return [round(d) for d in action]

def calc_reward(items, state, max_weight, burst_reward):
    reward = 0
    weight = 0
    
    for i in range(len(state)):
        reward += items[i][0] * state[i]
        weight += items[i][1] * state[i]
    
    if weight > max_weight or min(state) < 0:
        reward = burst_reward
    
    return reward, weight

class Knapsack(gym.Env):
    def __init__(self, config):
        self.items = config["items"]
        self.max_weight = config["max_weight"]
        self.burst_reward = config["burst_reward"]
        # 個数の最大値
        n = self.max_weight // min(np.array(self.items)[:, 1])
        
        self.action_space = Box(0, n, shape = (len(self.items), ))
        self.observation_space = Box(0, n, shape = (len(self.items), ))

    def reset(self):
        return [0 for _ in self.items]

    def step(self, action):
        state = next_state(action)
        
        reward, _ = calc_reward(self.items, state, self.max_weight, self.burst_reward)
        
        return state, reward, True, {}
設定
items = [
    [105, 10],
    [74, 7],
    [164, 15],
    [32, 3],
    [235, 22]
]

config = {
    "env": Knapsack, 
    "env_config": {"items": items, "max_weight": 35, "burst_reward": -100}
}

(a) PPO(Proximal Policy Optimization)

PPO で実施してみます。

基本的な処理内容は 1. サンプル1 と同じですが、行動 1回でエピソードが終了するため、以下のコードで結果を確認する事にします。

評価
import collections

rs = []

for _ in range(100):
    s = [0 for _ in range(len(items))]
    a = trainer.compute_action(s)
    
    s = next_state(a)
    
    r, _ = calc_reward(items, s, config["env_config"]["max_weight"], config["env_config"]["burst_reward"])
    
    rs.append(r)

collections.Counter(rs)

70回学習後の結果です。

評価結果 - PPO 学習 70回
Counter({375.0: 100})

学習時の状況は以下のようになりました。

学習時の報酬グラフ - PPO 学習 70回

f:id:fits:20200922224006p:plain

(b) DDPG(Deep Deterministic Policy Gradient)

action_space が Box の場合に、DQNTrainer が UnsupportedSpaceException エラーを発生させるので、DQN は使えませんでした。

そこで、DDPG を使ってみました。

トレーナーの定義 - DDPG
from ray.rllib.agents.ddpg import DDPGTrainer

trainer = DDPGTrainer(config = config)

こちらは、PPO 等と比べて処理に時間がかかる上に、80回学習しても順調とはいえない結果となりました。

学習時の報酬グラフ - DDPG 学習 80回

f:id:fits:20200922224049p:plain

評価結果 - DDPG 学習 80回
Counter({347.0: 18, -100: 79, 242.0: 3})

以下のように評価の回数を 1000 にして再度確認してみます。

for _ in range(1000):
    s = [0 for _ in range(len(items))]
    a = trainer.compute_action(s)
    
    ・・・

collections.Counter(rs)
評価結果(1000回) - DDPG 学習 80回
Counter({347.0: 232,
         -100: 718,
         242.0: 24,
         315.0: 14,
         210.0: 1,
         316.0: 7,
         274.0: 3,
         374.0: 1})

347(3, 0, 0, 1, 0)が多くなっているのはよく分かりませんが、DDPG も本件に向いていないのかもしれません。

Deno から npm パッケージを使用する(Deno で fp-ts)

下記の方法を用いて Node.js / ブラウザ用 npm パッケージを Deno から利用してみました。

npm パッケージは関数プログラミング用 TypeScript ライブラリの fp-ts を試すことにします。

fp-ts は CommonJS と ES Modules のモジュール形式に対応していますが、現時点で Deno に対する直接的なサポートは無さそうでした。

また、使用した Deno のバージョンは以下の通りです。

  • Deno 1.3.1

今回のサンプルコードは http://github.com/fits/try_samples/tree/master/blog/20200825/

はじめに

まずは、Node.js でサンプルコードを作成してみました。

sample.ts
import { Option, some, none, map } from 'fp-ts/lib/Option'
import { pipe } from 'fp-ts/lib/pipeable'

const f = (d: Option<number>) => 
    pipe(
        d,
        map(v => v + 2),
        map(v => v * 3)
    )

console.log( f(some(5)) )
console.log( f(none) )

実行結果は以下の通りです。

sample.ts 実行結果
> npm install ts-node typescript fp-ts
・・・

> ts-node sample.ts
{ _tag: 'Some', value: 21 }
{ _tag: 'None' }

(a) Skypack の使用

前回の GraphQL.js でも利用しましたが、Skypack は npm パッケージをブラウザから直接使えるようにするための CDN となっています。

CommonJS 形式の npm パッケージを ES Modules 形式で提供する機能や Deno をサポートする機能(https://docs.skypack.dev/code/deno)が用意されているようです。

型情報なし

まずは、fp-ts の ES Modules のファイル(es6 に配置されている)を Skypack から import してみます。

TypeScript 用の型定義を指定しないと、関数の引数や戻り値などは any 型として扱う事になります。

また、import の際に拡張子を指定しなくても Skypack が .js ファイルを返してくれます。

a_1.ts
import { some, none, map } from 'https://cdn.skypack.dev/fp-ts/es6/Option.js'
// 以下でも同じ
//import { some, none, map } from 'https://cdn.skypack.dev/fp-ts/es6/Option'
import { pipe } from 'https://cdn.skypack.dev/fp-ts/es6/pipeable.js'

const f = (d: any) => 
    // @ts-ignore
    pipe(
        d,
        map( (v: number) => v + 2 ),
        map( (v: number) => v * 3 )
    )

console.log( f(some(5)) )
console.log( f(none) )
a_1.ts 実行結果
> deno run a_1.ts
・・・
{ _tag: "Some", value: 21 }
{ _tag: "None" }

ここで、@ts-ignoreVisual Studio Code におけるエラー表示対策 ※ のために付けています。

 ※ pipe 関数は、10個の any 型の引数をとる関数となっているが、
    引数を 3つしか指定していない事に対するエラー
Visual Studio Code におけるエラー表示例(@ts-ignore を付けなかった場合)

f:id:fits:20200825214530p:plain

なお、関数の引数や戻り値などを適切な型で扱うには、型定義ファイル(.d.ts)の指定が必要になりますが、fp-ts が用意している型定義ファイルを以下のように @deno-types で指定しても上手くいきません。

型定義ファイル指定の失敗例
// @deno-types="https://cdn.skypack.dev/fp-ts/es6/Option.d.ts"
import { Option, some, none, map } from 'https://cdn.skypack.dev/fp-ts/es6/Option.js'

// @deno-types="https://cdn.skypack.dev/fp-ts/es6/pipeable.d.ts"
import { pipe } from 'https://cdn.skypack.dev/fp-ts/es6/pipeable.js'

・・・

というのも、fp-ts の Option.d.tspipeable.d.ts では import { ・・・ } from './Alt' のように .d.ts の拡張子を付けずに他の型定義を import しており、不都合が生じます。※

 ※ この場合、Skypack は Alt.js を返すことになり、
    型情報を正しく取得できないと考えられる

ちなみに、Skypack 本来の使い方としては、以下のようにパッケージのルートを指定して import する事になりそうです。

a_2.ts
import { option, pipeable } from 'https://cdn.skypack.dev/fp-ts'

const { some, none, map } = option
const { pipe } = pipeable

const f = (d: any) => 
    // @ts-ignore
    pipe(
        d,
        map( (v: number) => v + 2 ),
        map( (v: number) => v * 3 )
    )

console.log( f(some(5)) )
console.log( f(none) )
a_2.ts 実行結果
> deno run a_2.ts
・・・
{ _tag: "Some", value: 21 }
{ _tag: "None" }

型定義を自作

型に関しては、型定義ファイルを自作する事で一応は解決できます。

例えば、以下のような型定義ファイルを用意します。

types/Option.d.ts
export interface None {
    readonly _tag: 'None'
}

export interface Some<A> {
    readonly _tag: 'Some'
    readonly value: A
}
export declare type Option<A> = None | Some<A>

export declare const some: <A>(a: A) => Option<A>
export declare const none: Option<never>
export declare const map: <A, B>(f: (a: A) => B) => (fa: Option<A>) => Option<B>
types/pipeable.d.ts
export declare function pipe<A, B, C, D, E, F, G, H, I, J>(
    a: A,
    ab: (a: A) => B,
    bc?: (b: B) => C,
    cd?: (c: C) => D,
    de?: (d: D) => E,
    ef?: (e: E) => F,
    fg?: (f: F) => G,
    gh?: (g: G) => H,
    hi?: (h: H) => I,
    ij?: (i: I) => J
): J

これを @deno-types で指定する事で型の問題が解決します。

a_3.ts
// @deno-types="./types/Option.d.ts"
import { Option, some, none, map } from 'https://cdn.skypack.dev/fp-ts/es6/Option.js'
// @deno-types="./types/pipeable.d.ts"
import { pipe } from 'https://cdn.skypack.dev/fp-ts/es6/pipeable.js'

const f = (d: Option<number>) => 
    pipe(
        d,
        map( v => v + 2 ),
        map( v => v * 3 )
    )

console.log( f(some(5)) )
console.log( f(none) )
a_3.ts 実行結果
> deno run a_3.ts
・・・
{ _tag: "Some", value: 21 }
{ _tag: "None" }

パッケージのルートを import する場合は、以下のような型定義ファイルを追加して @deno-types で指定します。

types/index.d.ts
import * as option from './Option.d.ts'
import * as pipeable from './pipeable.d.ts'

export {
    option,
    pipeable
}
a_4.ts
// @deno-types="./types/index.d.ts"
import { option, pipeable } from 'https://cdn.skypack.dev/fp-ts'

const { some, none, map } = option
const { pipe } = pipeable

const f = (d: option.Option<number>) => 
    pipe(
        d,
        map( v => v + 2 ),
        map( v => v * 3 )
    )

console.log( f(some(5)) )
console.log( f(none) )
a_4.ts 実行結果
> deno run a_4.ts
・・・
{ _tag: "Some", value: 21 }
{ _tag: "None" }

dts クエリパラメータの利用

Skypack には、Deno 用に型定義ファイルを解決する手段として dts クエリパラメータが用意されています。

これを使う事で、本来は以下のようなコードで型問題を解決できるはずですが、fp-ts 2.8.2 では上手くいきませんでした。

a_5e.ts
import { option, pipeable } from 'https://cdn.skypack.dev/fp-ts?dts'

const { some, none, map } = option
const { pipe } = pipeable

const f = (d: option.Option<number>) => 
    pipe(
        d,
        map( v => v + 2 ),
        map( v => v * 3 )
    )

console.log( f(some(5)) )
console.log( f(none) )

実行結果は以下のようになり、型定義ファイルの取得途中で 404 Not Found エラーが発生してしまいます。

a_5e.ts 実行結果
> deno run a_5e.ts
・・・
Download https://cdn.skypack.dev/-/fp-ts@v2.8.2-Hr9OPgW5wz4u6TqOfiZH/dist=es2020,mode=types/lib/TaskEither.d.ts
error: Import 'https://cdn.skypack.dev/-/fp-ts@v2.8.2-Hr9OPgW5wz4u6TqOfiZH/dist=es2020,mode=types/lib/HKT.d.ts' failed: 404 Not Found
Imported from "https://cdn.skypack.dev/-/fp-ts@v2.8.2-Hr9OPgW5wz4u6TqOfiZH/dist=es2020,mode=types/lib/index.d.ts:42"

これは、fp-ts の中で HKT だけ特殊な扱いがされており、HKT.d.ts ファイルが lib ディレクトリ内に配置されておらず、パッケージのルートディレクトリに配置されている事が原因だと考えられます。※

 ※ そのため、
    "/lib/HKT.d.ts" ではなく "/HKT.d.ts" を import する必要がある

    Node.js においては
    "/lib/HKT/package.json" の typings フィールドの値から
    HKT.d.ts の配置場所を取得するようになっていると思われるが、
    Skypack の dts クエリパラメータの機能では
    そこまでを考慮していない事が原因だと考えられる

ここで、dts クエリパラメータによって何が変わるのかというと、?dts を付けることでレスポンスヘッダーに x-typescript-types が付与され、型定義ファイル(.d.ts)の取得先が提示されるようになります。※

 ※ Deno は x-typescript-types ヘッダーの値から
    型定義ファイルを自動的に取得するようになっている
dts クエリパラメータの有無による違い
$ curl --head -s https://cdn.skypack.dev/fp-ts | grep x-typescript-types
$ curl --head -s https://cdn.skypack.dev/fp-ts?dts | grep x-typescript-types
x-typescript-types: /-/fp-ts@v2.8.2-Hr9OPgW5wz4u6TqOfiZH/dist=es2020,mode=types/lib/index.d.ts

更に、x-typescript-types のパスから取得できる index.d.ts は、fp-ts のオリジナル index.d.ts を(Deno 用に)加工したもの ※ となっています。

 ※ 型定義ファイル内の
    import 対象のパスに拡張子 .d.ts を加えている

    fp-ts の HKT で起きた問題を考えると、
    現時点では .d.ts ファイルの存在有無や
    package.json の typings フィールド等の考慮はされていないと考えられる
オリジナルの index.d.ts 内容
・・・
import * as alt from './Alt'
・・・
x-typescript-types の index.d.ts 内容
・・・
import * as alt from './Alt.d.ts'
・・・

ついでに GraphQL.js に対して確認してみましたが、こちらは(npm パッケージ内に .d.ts ファイルが存在するものの) ?dts を付けても x-typescript-types ヘッダーは付与されませんでした。

GraphQL.js の場合
$ curl --head -s https://cdn.skypack.dev/graphql?dts | grep x-typescript-types

(b) Deno Node compatibility の使用

Deno Node compatibility は Deno の std ライブラリに含まれている機能で CommonJS モジュールのロードだけではなく、Node.js API との互換 API もある程度は用意されているようです。(fs.readFile 等)

ただ、require の戻り値の型は any となってしまうので、TypeScript で使うメリットはあまり無いかもしれません。

b_1.ts
import { createRequire } from 'https://deno.land/std/node/module.ts'

const require = createRequire(import.meta.url)

const { some, none, map } = require('fp-ts/lib/Option')
const { pipe } = require('fp-ts/lib/pipeable')

const f = (d: any) => 
    pipe(
        d,
        map( (v: number) => v + 2),
        map( (v: number) => v * 3)
    )

console.log( f(some(5)) )
console.log( f(none) )

fp-ts を npm でインストールしてから deno run で実行します。 deno run で実行するには下記オプションを指定する必要がありました。

  • --unstable
  • --allow-read
  • --allow-env
b_1.ts 実行結果
> npm install fp-ts
・・・

> deno run --unstable --allow-read --allow-env b_1.ts
・・・
{ _tag: "Some", value: 21 }
{ _tag: "None" }

Deno で GraphQL

GraphQL を Deno で試してみました。

https://deno.land/x に Deno 用の GraphQL モジュールがいくつかありましたが(基本的には GraphQL.js のポーティング)、ここでは GraphQL.js を直接使う事にします。

今回のサンプルコードは http://github.com/fits/try_samples/tree/master/blog/20200817/

1. GraphQL.js の使用

GraphQL.js を Deno から使うには、以下のように SkypackPika CDN の後継)という CDN から import するのが簡単そうです。

import { ・・・ } from 'https://cdn.skypack.dev/graphql'

2. 単独実行

GraphQL.js を使って GraphQL の Query・Mutation・Subscription を一通り実行してみます。

まずは、buildSchema 関数に GraphQL の型定義を渡して GraphQLSchema を取得します。

これをクエリや GraphQL 処理の実装と共に graphql 関数や subscribe 関数へ渡す事で処理結果を得ます。

graphql 関数の方は第二引数に文字列を使えましたが(型は Source | string)、subscribe 関数の方は使えなかったので(型は DocumentNode)、こちらは parse した結果を渡すようにしています。

GraphQL 処理の実装(下記サンプルの root 変数の箇所)では、Query・Mutation・Subscription の処理に応じた関数を用意し、型定義に応じた値を返すように実装します。

Subscription の場合は、([Symbol.asyncIterator] プロパティで AsyncIterator を返す)AsyncIterable を戻り値としなければならないようなので、下記では MessageBox クラスとして実装しています。

sample1.js
import {
    graphql, buildSchema, subscribe, parse
} from 'https://cdn.skypack.dev/graphql'

import { v4 } from 'https://deno.land/std/uuid/mod.ts'

// GraphQL 型定義
const schema = buildSchema(`
    enum Category {
        Standard
        Extra
    }

    input CreateItem {
        category: Category!
        value: Int!
    }

    type Item {
        id: ID!
        category: Category!
        value: Int!
    }

    type Mutation {
        create(input: CreateItem!): Item
    }

    type Query {
        find(id: ID!): Item
    }

    type Subscription {
        created: Item
    }
`)

// Subscription 用の AsyncIterable
class MessageBox {
    #messages = []
    #resolves = []

    publish(value) {
        const resolve = this.#resolves.shift()

        if (resolve) {
            resolve({ value })
            // 以下でも可
            // resolve({ value, done: false })
        }
        else {
            this.#messages.push(value)
        }
    }

    [Symbol.asyncIterator]() {
        return {
            next: () => {
                console.log('** asyncIterator next')

                return new Promise(resolve => {
                    const value = this.#messages.shift()

                    if (value) {
                        resolve({ value })
                        // 以下でも可
                        // resolve({ value, done: false })
                    }
                    else {
                        this.#resolves.push(resolve)
                    }
                })
            }
        }
    }
}

const store = {}
const box = new MessageBox()

// GraphQL 処理の実装
const root = {
    create: ({ input: { category, value } }) => {
        console.log(`* call create: category = ${category}, value = ${value}`)

        const id = `item-${v4.generate()}`
        const item = { id, category, value }

        store[id] = item
        box.publish({ created: item })

        return item
    },
    find: ({ id }) => {
        console.log(`* call find: ${id}`)
        return store[id]
    },
    // Subscription の実装では戻り値を AsyncIterable とする
    created: () => box
}

const run = async () => {
    const m1 = `
        mutation {
            create(input: { category: Standard, value: 10 }) {
                id
            }
        }
    `
    // Mutation の実行1
    const mr1 = await graphql(schema, m1, root)
    console.log(mr1)

    const m2 = `
        mutation Create($p: CreateItem!) {
            create(input: $p) {
                id
            }
        }
    `

    const vars = {
        p: {
            category: 'Extra',
            value: 123
        }
    }
    // Mutation の実行2(変数利用)
    const mr2 = await graphql(schema, m2, root, null, vars)
    console.log(mr2)

    const q = `
        {
            find(id: "${mr2.data.create.id}") {
                id
                category
                value
            }
        }
    `
    // Query の実行
    const qr = await graphql(schema, q, root)
    console.log(qr)

    const s = parse(`
        subscription {
            created {
                id
                category
            }
        }
    `)
    // Subscription 処理
    const subsc = await subscribe(schema, s, root)

    for await (const r of subsc) {
        console.log('*** subscribe')
        console.log(r)
    }
}

run().catch(err => console.error(err))

実行結果は以下の通りです。

実行結果
> deno run sample1.js

・・・
* call create: category = Standard, value = 10
{ data: { create: { id: "item-11ca8326-832b-4e13-9b47-d2c70c1e95e9" } } }
* call create: category = Extra, value = 123
{ data: { create: { id: "item-29f1851b-4cbb-4f23-872f-11856f1c0bf7" } } }
* call find: item-29f1851b-4cbb-4f23-872f-11856f1c0bf7
{
  data: {
    find: { id: "item-29f1851b-4cbb-4f23-872f-11856f1c0bf7", category: "Extra", value: 123 }
  }
}
** asyncIterator next
*** subscribe
{
  data: {
    created: { id: "item-11ca8326-832b-4e13-9b47-d2c70c1e95e9", category: "Standard" }
  }
}
** asyncIterator next
*** subscribe
{
  data: {
    created: { id: "item-29f1851b-4cbb-4f23-872f-11856f1c0bf7", category: "Extra" }
  }
}
** asyncIterator next

3. Web サーバー化

Deno の http を使って、上記処理を Web サーバー化してみました。

sample2.js
import {
    graphql, buildSchema, subscribe, parse
} from 'https://cdn.skypack.dev/graphql'

import { v4 } from 'https://deno.land/std/uuid/mod.ts'
import { serve } from 'https://deno.land/std@0.65.0/http/server.ts'

const schema = buildSchema(`
    enum Category {
        Standard
        Extra
    }

    ・・・
`)

class MessageBox {
    ・・・
}

const store = {}
const box = new MessageBox()

const root = {
    create: ({ input: { category, value } }) => {
        const id = `item-${v4.generate()}`
        const item = { id, category, value }

        store[id] = item
        box.publish({ created: item })

        return item
    },
    find: ({ id }) => {
        return store[id]
    },
    created: () => box
}

// Web サーバー処理
const run = async () => {
    const server = serve({ port: 8080 })

    // リクエスト毎の処理
    for await (const req of server) {
        const buf = await Deno.readAll(req.body)
        const query = new TextDecoder().decode(buf)

        console.log(`* query: ${query}`)

        const res = await graphql(schema, query, root)

        req.respond({ body: JSON.stringify(res) })
    }
}

const runSubscribe = async () => {
    const s = parse(`
        subscription {
            created {
                id
                category
                value
            }
        }
    `)

    const subsc = await subscribe(schema, s, root)

    for await (const r of subsc) {
        console.log(`*** subscribe: ${JSON.stringify(r)}`)
    }
}

Promise.all([
    run(),
    runSubscribe()
]).catch(err => console.error(err))

実行

deno run で Web サーバーを起動します。 --allow-net でネットワーク処理を許可する必要があります。

Web サーバー起動
> deno run --allow-net sample2.js

create を実行した結果です。

Mutation の実行
$ curl -s http://localhost:8080 -d 'mutation { create(input: { category: Extra, value: 5 }) { id } }'

{"data":{"create":{"id":"item-7fc5e3bb-286f-48fe-b027-9ca34dcc6451"}}}

create で返された id を指定して find した結果です。

Query の実行1
$ curl -s http://localhost:8080 -d '{ find(id: "item-7fc5e3bb-286f-48fe-b027-9ca34dcc6451") { id category value } }'

{"data":{"find":{"id":"item-7fc5e3bb-286f-48fe-b027-9ca34dcc6451","category":"Extra","value":5}}}

存在しない id を指定して find した結果です。

Query の実行2
$ curl -s http://localhost:8080 -d '{ find(id: "item-invalid") { id category value } }'

{"data":{"find":null}}

サーバー側の出力結果は以下の通りです。

Web サーバー出力結果
> deno run --allow-net sample2.js

・・・
** asyncIterator next
* query: mutation { create(input: { category: Extra, value: 5 }) { id } }
*** subscribe: {"data":{"created":{"id":"item-7fc5e3bb-286f-48fe-b027-9ca34dcc6451","category":"Extra","value":5}}}
** asyncIterator next
* query: { find(id: "item-7fc5e3bb-286f-48fe-b027-9ca34dcc6451") { id category value } }
* query: { find(id: "item-invalid") { id category value } }

rusty_v8 を使って Rust から JavaScript を実行

Node.js の製作者が新たに作り直した Deno という JavaScript/TypeScript 実行環境があります。

Deno の内部では、V8 JavaScript エンジンの呼び出しに rusty_v8 という Rust 用バインディングを使っていたので、今回はこの rusty_v8 を使って Rust コード内で JavaScript コードを実行してみました。

今回のサンプルコードは http://github.com/fits/try_samples/tree/master/blog/20200705/

設定

rusty_v8 を使うための Cargo 用の dependencies 設定は以下のようになります。

Cargo.toml
・・・
[dependencies]
rusty_v8 = "0.6"

JavaScript コード実行

以下の JavaScript コードを実行し結果(1 ~ 5 の合計値)を出力する処理を Rust で実装してみます。

実行する JavaScript コード
const vs = [1, 2, 3, 4, 5]
console.log(vs)
vs.reduce((acc, v) => acc + v, 0)

基本的に、下記 V8 API による手順と同様の処理を rusty_v8 の API で実装すればよさそうです。

V8 エンジンのインスタンスである Isolate(独自のヒープを持ち他のインスタンスとは隔離される)、GC で管理するオブジェクトへの参照をまとめて管理する HandleScopeサンドボックス化された実行コンテキストの Context(組み込みのオブジェクト・関数を管理)をそれぞれ作成していきます。

そして、JavaScript のコードを Script::compileコンパイルして、run で実行します。

run の戻り値は Option<Local<Value>> となっているので、ここでは to_string を使って JavaScript の String(Option<Local<rusty_v8::String>>)として取得し、rusty_v8::String を to_rust_string_lossy を使って Rust の String へ変換して出力しています。

src/sample1.rs
use rusty_v8 as v8;

fn main() {
    let platform = v8::new_default_platform().unwrap();
    v8::V8::initialize_platform(platform);
    v8::V8::initialize();

    let isolate = &mut v8::Isolate::new(Default::default());

    let scope = &mut v8::HandleScope::new(isolate);
    let context = v8::Context::new(scope);
    let scope = &mut v8::ContextScope::new(scope, context);

    // JavaScript コード
    let src = r#"
        const vs = [1, 2, 3, 4, 5]
        console.log(vs)
        vs.reduce((acc, v) => acc + v, 0)
    "#;

    v8::String::new(scope, src)
        .map(|code| {
            println!("code: {}", code.to_rust_string_lossy(scope));
            code
        })
        .and_then(|code| v8::Script::compile(scope, code, None)) //コンパイル
        .and_then(|script| script.run(scope)) //実行
        .and_then(|value| value.to_string(scope)) // rusty_v8::Value を rusty_v8::String へ
        .iter()
        .for_each(|s| println!("result: {}", s.to_rust_string_lossy(scope)));
}

実行

複数の実行ファイルに対応した下記 Cargo.toml を使って実行します。

Cargo.toml
[package]
・・・
default-run = "sample1"

[dependencies]
rusty_v8 = "0.6"

[[bin]]
name = "sample1"
path = "src/sample1.rs"

[[bin]]
name = "sample2"
path = "src/sample2.rs"

[[bin]]
name = "sample3"
path = "src/sample3.rs"

sample1(src/sample1.rs)の実行結果は以下の通りです。

console.log に関しては何も処理されていませんが、JavaScript の実行結果は取得できています。

sample1 実行結果
> cargo run

・・・
     Running `target\debug\sample1.exe`
code:
        const vs = [1, 2, 3, 4, 5]
        console.log(vs)
        vs.reduce((acc, v) => acc + v, 0)

result: 15

Inspector 機能

次に、console.log されたログメッセージを Rust から出力するようにしてみます。

V8 で console.log (デバッグコンソールへのログ出力)のようなデバッグ機能を処理するには Inspector という機能を使うようです。

rusty_v8 では、console.log 等の呼び出し時に V8InspectorClientImpl トレイトの console_api_message が呼びされるようになっているため、これを実装した構造体のインスタンスから V8Inspector を作成して context_created を実行する事で実現できます。

context_created のシグネチャfn context_created(&mut self, context: Local<Context>, context_group_id: i32, human_readable_name: StringView) となっており、context_group_id 引数へ指定した値が、console_api_message の context_group_id 引数の値となります。(human_readable_name の用途に関してはよく分からなかった)

また、console_api_message の level 引数の値はログレベル(V8 API の MessageErrorLevel) のようです。

ちなみに、V8InspectorClientImpl トレイトは basebase_mut の実装が必須でした。(console_api_message は空実装されている)

src/sample2.rs
use rusty_v8 as v8;
use rusty_v8::inspector::*;

struct InspectorClient(V8InspectorClientBase);

impl InspectorClient {
    fn new() -> Self {
        Self(V8InspectorClientBase::new::<Self>())
    }
}

impl V8InspectorClientImpl for InspectorClient {
    fn base(&self) -> &V8InspectorClientBase {
        &self.0
    }

    fn base_mut(&mut self) -> &mut V8InspectorClientBase {
        &mut self.0
    }

    fn console_api_message(&mut self, _context_group_id: i32, 
        _level: i32, message: &StringView, _url: &StringView, 
        _line_number: u32, _column_number: u32, _stack_trace: &mut V8StackTrace) {
        // ログメッセージの出力
        println!("{}", message);
    }
}

fn main() {
    ・・・

    let isolate = &mut v8::Isolate::new(Default::default());

    // V8Inspector の作成
    let mut client = InspectorClient::new();
    let mut inspector = V8Inspector::create(isolate, &mut client);

    let scope = &mut v8::HandleScope::new(isolate);
    let context = v8::Context::new(scope);
    let scope = &mut v8::ContextScope::new(scope, context);

    // context_created の実行
    inspector.context_created(context, 1, StringView::empty());

    let src = r#"
        const vs = [1, 2, 3, 4, 5]
        console.log(vs)
        vs.reduce((acc, v) => acc + v, 0)
    "#;

    ・・・
}

実行結果は以下の通り。

console.log(vs) を処理して 1,2,3,4,5 が出力されるようになりました。

sample2 実行結果
> cargo run --bin sample2

・・・
     Running `target\debug\sample2.exe`
code:
        const vs = [1, 2, 3, 4, 5]
        console.log(vs)
        vs.reduce((acc, v) => acc + v, 0)

1,2,3,4,5
result: 15

最後に、context_group_idlevel の値も出力するようにしてみます。

src/sample3.rs
・・・

impl V8InspectorClientImpl for InspectorClient {
    ・・・

    fn console_api_message(&mut self, context_group_id: i32, 
        level: i32, message: &StringView, _url: &StringView, 
        _line_number: u32, _column_number: u32, _stack_trace: &mut V8StackTrace) {

            println!(
                "*** context_group_id={}, level={}, message={}", 
                context_group_id, 
                level, 
                message
            );
    }
}

fn main() {
    ・・・

    inspector.context_created(context, 123, StringView::empty());

    let src = r#"
        console.log('log')
        console.debug('debug')
        console.info('info')
        console.error('error')
        console.warn('warn')
    "#;

    v8::String::new(scope, src)
        .and_then(|code| v8::Script::compile(scope, code, None))
        .and_then(|script| script.run(scope))
        .and_then(|value| value.to_string(scope))
        .iter()
        .for_each(|s| println!("result: {}", s.to_rust_string_lossy(scope)));
}

実行結果は以下の通りです。

sample3 実行結果
> cargo run --bin sample3

・・・
     Running `target\debug\sample3.exe`
*** context_group_id=123, level=4, message=log
*** context_group_id=123, level=2, message=debug
*** context_group_id=123, level=4, message=info
*** context_group_id=123, level=8, message=error
*** context_group_id=123, level=16, message=warn
result: undefined

Rust で WASI 対応の WebAssembly を作成して実行

Rust で WASI 対応の WebAssembly を作って、スタンドアロン実行や Web ブラウザ上での実行を試してみました。

WASI(WebAssembly System Interface) は WebAssembly のコードを様々なプラットフォームで実行するためのインターフェースで、これに対応した WebAssembly であれば Web ブラウザ外で実行できます。

Rust で WASI 対応の WebAssembly を作るのは簡単で、ビルドターゲットに wasm32-wasi を追加しておいて、rustccargo build によるビルド時に --target wasm32-wasi を指定するだけでした。

wasm32-wasi の追加
> rustup target add wasm32-wasi

標準出力へ文字列を出力するだけの下記サンプルコードを --target wasm32-wasi でビルドすると sample1.wasm ファイルが作られました。

sample1.rs
fn main() {
    for i in 1..=3 {
        println!("count-{}", i);
    }

    print!("aaa");
    print!("bbb");
}
WASI 対応の WebAssembly として sample1.rs をビルド
> rustc --target wasm32-wasi sample1.rs

なお、今回のビルドに使用した Rust のバージョンは以下の通りです。

  • Rust 1.43.0

また、使用したソースコードhttp://github.com/fits/try_samples/tree/master/blog/20200429/ に置いてあります。

(1) スタンドアロン用ランタイムで実行

sample1.wasm を WebAssembly のランタイム wasmtimewasmer でそれぞれ実行してみます。

(1-a) wasmtime で実行

wasmtime v0.15.0 による sample1.wasm 実行結果
> wasmtime sample1.wasm
count-1
count-2
count-3
aaabbb

(1-b) wasmer で実行

wasmer v0.16.2 による sample1.wasm 実行結果
> wasmer sample1.wasm
count-1
count-2
count-3
aaabbb

どちらのランタイムでも問題なく実行できました。

(2) Web ブラウザ上で実行

次は、sample1.wasm を外部ライブラリ等を使わずに Web ブラウザ上で実行してみます。

主要な Web ブラウザや Node.js は JavaScript 用の WebAssembly API に対応済みのため、WebAssembly を実行可能です。

WASI 対応 WebAssembly の場合、実行対象の WebAssembly がインポートしている WASI の関数(の実装)を WebAssembly インスタンス化関数(WebAssembly.instantiate()WebAssembly.instantiateStreaming())の第二引数(引数名 importObject)として渡す必要があるようです。

(2-a) WebAssembly のインポート内容を確認

WebAssembly.compile() 関数で取得した WebAssembly.Module オブジェクトを WebAssembly.Module.imports() 関数へ渡す事で、その WebAssembly がインポートしている内容を取得できます。

ここでは、以下の Node.js スクリプトを使って WebAssembly のインポート内容を確認してみました。

wasm_listup_imports.js (WebAssembly のインポート内容を出力)
const fs = require('fs')

const wasmFile = process.argv[2]

const run = async () => {
    const module = await WebAssembly.compile(fs.readFileSync(wasmFile))

    const imports = WebAssembly.Module.imports(module)

    console.log(imports)
}

run().catch(err => console.error(err))

sample1.wasm へ適用してみると以下のような結果となりました。

インポート内容の出力結果(Node.js v12.16.2 で実行)
> node wasm_listup_imports.js sample1.wasm
[
  {
    module: 'wasi_snapshot_preview1',
    name: 'proc_exit',
    kind: 'function'
  },
  {
    module: 'wasi_snapshot_preview1',
    name: 'fd_write',
    kind: 'function'
  },
  {
    module: 'wasi_snapshot_preview1',
    name: 'fd_prestat_get',
    kind: 'function'
  },
  {
    module: 'wasi_snapshot_preview1',
    name: 'fd_prestat_dir_name',
    kind: 'function'
  },
  {
    module: 'wasi_snapshot_preview1',
    name: 'environ_sizes_get',
    kind: 'function'
  },
  {
    module: 'wasi_snapshot_preview1',
    name: 'environ_get',
    kind: 'function'
  }
]

この結果から、sample1.wasm は以下のようにしてインスタンス化できる事になります。

WebAssembly インスタンス化の例
const importObject = {
    wasi_snapshot_preview1: {
        proc_exit: () => {・・・},
        fd_write: () => {・・・},
        fd_prestat_get: () => {・・・},
        fd_prestat_dir_name: () => {・・・},
        environ_sizes_get: () => {・・・},
        environ_get: () => {・・・}
    }
}

WebAssembly.instantiate(・・・, importObject)
    ・・・

(2-b) fd_write 関数の実装

Rust の println! で呼び出される WASI の関数は fd_write なので、これを実装してみます。

fd_write の引数は 4つで、第一引数 fd は出力先のファイルディスクリプタで標準出力の場合は 1、第二引数 iovs は出力内容へのポインタ、第三引数 iovsLen は出力内容の数、第四引数 nwritten は出力済みのバイト数を設定するポインタとなっています。

なお、ポインタの対象は WebAssembly.instantiate() で取得した WebAssembly のインスタンスに含まれている WebAssembly.Memory です。

出力内容は iovs ポインタの位置から 4バイト毎に以下のような並びで情報が格納されているようなので、これを基に出力対象の文字列を取得して出力する事になります。

  • 1個目の出力内容の格納先ポインタ(4バイト)
  • 1個目の出力内容のバイトサイズ(4バイト)
  • ・・・
  • iovsLen 個目の出力内容の格納先ポインタ(4バイト)
  • iovsLen 個目の出力内容のバイトサイズ(4バイト)

何処まで処理を行ったか(出力したか)を返すために、nwritten ポインタの位置へ出力の完了したバイトサイズを設定します。

fd_write の実装例(wasmInstance には WebAssembly のインスタンスを設定)
・・・
fd_write: (fd, iovs, iovsLen, nwritten) => {
    const memory = wasmInstance.exports.memory.buffer
    const view = new DataView(memory)

    const sizeList = Array.from(Array(iovsLen), (v, i) => {
        const ptr = iovs + i * 8

        // 出力内容の格納先のポインタ取得
        const bufStart = view.getUint32(ptr, true)
        // 出力内容のバイトサイズを取得
        const bufLen = view.getUint32(ptr + 4, true)

        const buf = new Uint8Array(memory, bufStart, bufLen)

        // 出力内容の String 化
        const msg = String.fromCharCode(...buf)

        // 出力
        console.log(msg)

        return buf.byteLength
    })

    // 出力済みのバイトサイズ合計
    const totalSize = sizeList.reduce((acc, v) => acc + v)

    // 出力済みのバイトサイズを設定
    view.setUint32(nwritten, totalSize, true)

    return 0
},
・・・

最終的な HTML は下記のようになりました。

fd_write 以外の WASI 関数を空実装にして main 関数を呼び出して実行するようにしていますが、WASI の仕様としては _start 関数を呼び出すのが正しいようです ※。(WASI Application ABI 参照)

 ※ _start 関数を使う場合、fd_prestat_get 等の実装も必要となります

WebAssembly がインポートしている WASI 関数の実装をインスタンス化時(WebAssembly.instantiateStreaming)に渡す事になりますが、WASI の関数(fd_write 等)はインスタンス化の結果を使って処理する点に注意が必要です。

index.html(main 関数版)
<!DOCTYPE html>
<html>
<head>
<meta charset="utf-8" />
</head>
<body>
  <h1>WASI WebAssembly Sample</h1>
  <div id="res"></div>

  <script>
    const WASM_URL = './sample1.wasm'

    const wasiObj = {
      wasmInstance: null,
      importObject: {
        wasi_snapshot_preview1: {
          fd_write: (fd, iovs, iovsLen, nwritten) => {
            console.log(`*** call fd_write: fd=${fd}, iovs=${iovs}, iovsLen=${iovsLen}, nwritten=${nwritten}`)

            const memory = wasiObj.wasmInstance.exports.memory.buffer
            const view = new DataView(memory)

            const sizeList = Array.from(Array(iovsLen), (v, i) => {
              const ptr = iovs + i * 8

              const bufStart = view.getUint32(ptr, true)
              const bufLen = view.getUint32(ptr + 4, true)

              const buf = new Uint8Array(memory, bufStart, bufLen)

              const msg = String.fromCharCode(...buf)

              // 出力
              console.log(msg)
              document.getElementById('res').innerHTML += `<p>${msg}</p>`

              return buf.byteLength
            })

            const totalSize = sizeList.reduce((acc, v) => acc + v)

            view.setUint32(nwritten, totalSize, true)

            return 0
          },
          proc_exit: () => {},
          fd_prestat_get: () => {},
          fd_prestat_dir_name: () => {},
          environ_sizes_get: () => {},
          environ_get: () => {}
        }
      }
    }

    WebAssembly.instantiateStreaming(fetch(WASM_URL), wasiObj.importObject)
      .then(res => {
        console.log(res)

        // fd_write で参照できるようにインスタンスを wasmInstance へ設定
        wasiObj.wasmInstance = res.instance

        // main 関数の実行
        wasiObj.wasmInstance.exports.main()
      })
      .catch(err => console.error(err))
  </script>
</body>
</html>

main 関数の代わりに _start 関数を呼び出す場合は下記のようになりました。

_start 関数版の場合、fd_prestat_get の実装が重要となります ※。

 ※ fd_prestat_get を正しく実装していないと、
    fd_prestat_get の呼び出しが延々と繰り返されてしまいました

今回はファイル等を使っていないので(file descriptor 3 以降を開いていない)、fd_prestat_get は単に 8(WASI_EBADF, Bad file descriptor)を返すだけで良さそうです。

index2.html(_start 関数版)
・・・
  <script>
    const WASM_URL = './sample1.wasm'

    const wasiObj = {
      wasmInstance: null,
      importObject: {
        wasi_snapshot_preview1: {
          ・・・
          fd_prestat_get: () => 8,
          ・・・
        }
      }
    }

    WebAssembly.instantiateStreaming(fetch(WASM_URL), wasiObj.importObject)
      .then(res => {
        console.log(res)

        wasiObj.wasmInstance = res.instance
        // _start 関数の実行
        wasiObj.wasmInstance.exports._start()
      })
      .catch(err => console.error(err))
  </script>
・・・

(2-c) 実行

上記の .html ファイルを Web ブラウザで直接開いても WebAssembly を実行できないため、HTTP サーバーを使う事になります。

更に、Web ブラウザ上で WebAssembly を実行するには、.wasm ファイルを MIME Type application/wasm で取得する必要があるようです。

Python の http.server は application/wasm に対応していたため(Python 3.8.2 と 3.7.6 で確認)、以下のスクリプトで HTTP サーバーを立ち上げる事にしました。

web_server.py
import http.server
import socketserver

PORT = 8080

Handler = http.server.SimpleHTTPRequestHandler

with socketserver.TCPServer(("", PORT), Handler) as httpd:
    print(f"start server port:{PORT}")
    httpd.serve_forever()
HTTP サーバー起動(Python 3.8.2 で実行)
> python web_server.py
start server port:8080

Web ブラウザ(Chrome)で http://localhost:8080/index.html へアクセスしたところ(index2.html でも同様)、sample1.wasm の実行を確認できました。

Chrome の実行結果

f:id:fits:20200429200738p:plain

(3) Node.js で組み込み実行

次は、Node.js で WebAssembly を組み込み実行してみます。

(3-a) fd_write 実装

上記 index2.html の処理をベースにローカルの .wasm ファイルを読み込んで実行するようにしました。

sample1.wasm のインポート内容に合わせたものなので、インポート内容の異なる WebAssembly の実行には使えません。

wasm_run_sample.js
const fs = require('fs')

const WASI_ESUCCESS = 0;
const WASI_EBADF = 8; // Bad file descriptor

const wasmFile = process.argv[2]

const wasiObj = {
    wasmInstance: null,
    importObject: {
        wasi_snapshot_preview1: {
            fd_write: (fd, iovs, iovsLen, nwritten) => {
                ・・・
                
                const sizeList = Array.from(Array(iovsLen), (v, i) => {
                    ・・・
                    
                    process.stdout.write(msg)
                    
                    return buf.byteLength
                })
                
                ・・・
                
                return WASI_ESUCCESS
            },
            ・・・
            fd_prestat_get: (fd, bufPtr) => { 
                console.log(`*** call fd_prestat_get: fd=${fd}, bufPtr=${bufPtr}`)
                return WASI_EBADF
            },
            ・・・
        }
    }
}

const buf = fs.readFileSync(wasmFile)

WebAssembly.instantiate(buf, wasiObj.importObject)
    .then(res => {
        wasiObj.wasmInstance = res.instance
        wasiObj.wasmInstance.exports._start()
    })
    .catch(err => console.error(err))
実行結果(Node.js v12.16.2 で実行)
> node wasm_run_sample.js sample1.wasm
*** call fd_prestat_get : fd=3, bufPtr=1048568
*** call fd_write: fd=1, iovs=1047968, iovsLen=1, nwritten=1047948
count-1
*** call fd_write: fd=1, iovs=1047968, iovsLen=1, nwritten=1047948
count-2
*** call fd_write: fd=1, iovs=1047968, iovsLen=1, nwritten=1047948
count-3
*** call fd_write: fd=1, iovs=1048432, iovsLen=1, nwritten=1048412
aaabbb

(3-b) Wasmer-JS 使用

Wasmer-JS@wasmer/wasi モジュールを使って、もっと汎用的に組み込み実行できるようにしてみます。

@wasmer/wasi インストール例
> npm install @wasmer/wasi

@wasmer/wasi の WASI を使う事で、インポート内容に合わせた WASI 関数の取得や _start 関数の呼び出しを任せる事ができます。

run_wasmer_js/index.js
const fs = require('fs')
const { WASI } = require('@wasmer/wasi')

const wasmFile = process.argv[2]

const wasi = new WASI()

const run = async () => {
    const module = await WebAssembly.compile(fs.readFileSync(wasmFile))
    // インポート内容に合わせた WASI 関数の実装を取得
    const importObject = wasi.getImports(module)

    const instance = await WebAssembly.instantiate(module, importObject)

    // 実行
    wasi.start(instance)
}

run().catch(err => console.error(err))
実行結果(Node.js v12.16.2 で実行)
> node index.js ../sample1.wasm
count-1
count-2
count-3
aaabbb

(4) 標準出力以外の機能

最後に、現時点でどんな機能を使えるのか気になったので、いくつか試してみました。

まず、TcpStream を使ったコードの wasm32-wasi ビルドは一応成功しました。

sample2.rs
use std::net::TcpStream;

fn main() {
    let res = TcpStream::connect("127.0.0.1:8080");
    println!("{:?}", res);
}
sample2.rs ビルド
> rustc --target wasm32-wasi sample2.rs

ただし、実行してみると以下のような結果となりました。(wasmtime 以外で実行しても同じ)

wasmtime による sample2.wasm 実行結果
> wasmtime sample2.wasm

Err(Custom { kind: Other, error: "operation not supported on wasm yet" })

Rust のソースコードで該当(すると思われる)箇所を確認してみると、unsupported() を返すようになっていました。

https://github.com/rust-lang/rust/blob/master/src/libstd/sys/wasi/net.rs(2020/4/26 時点)
・・・
impl TcpStream {
    pub fn connect(_: io::Result<&SocketAddr>) -> io::Result<TcpStream> {
        unsupported()
    }
    ・・・
}
・・・

https://github.com/rust-lang/rust/blob/master/src/libstd/sys/wasi/ のソースを確認してみると、(スレッド系の処理等)他にも未対応の機能がいくつもありました。

一方で、環境変数・システム時間・スリープ処理は使えそうだったので、以下のコードで確認してみました。

sample3.rs
use std::env;
use std::thread::sleep;
use std::time::{ Duration, SystemTime };

fn main() {
    // 環境変数 SLEEP_TIME からスリープする秒数を取得(デフォルトは 1)
    let sleep_sec = env::var("SLEEP_TIME").ok()
                         .and_then(|v| v.parse::<u64>().ok())
                         .unwrap_or(1);

    // システム時間の取得
    let time = SystemTime::now();

    println!("start: sleep {}s", sleep_sec);

    // スリープの実施
    sleep(Duration::from_secs(sleep_sec));

    // 経過時間の出力
    match time.elapsed() {
        Ok(s) => println!("end: elapsed {}s", s.as_secs()),
        Err(e) => println!("error: {:?}", e),
    }
}
sample3.rs のビルド
> rustc --target wasm32-wasi sample3.rs

wasmtime では正常に実行できましたが、wasmer は今のところスリープ処理に対応していないようでエラーとなりました。

ちなみに、環境変数はどちらのコマンドも --env で指定できました。

wasmtime v0.15.0 による sample3.wasm 実行結果
> wasmtime --env SLEEP_TIME=5 sample3.wasm
start: sleep 5s
end: elapsed 5s
wasmer v0.16.2 による sample3.wasm 実行結果
> wasmer sample3.wasm --env SLEEP_TIME=5
start: sleep 5s
thread 'main' panicked at 'not yet implemented: Polling not implemented for clocks yet', lib\wasi\src\syscalls\mod.rs:2373:21
note: run with `RUST_BACKTRACE=1` environment variable to display a backtrace.
Error: error: "unhandled trap at 7fffd474a799 - code #e06d7363: unknown exception code"