Problems importing a BERT model with the PyTorch Relay frontend

Hi, I was trying to import bert-base-uncased with the PyTorch Relay frontend, but it fails with "NotImplementedError: The following operators are not implemented: ['prim::ImplicitTensorToNum']". I can't find any useful information about 'ImplicitTensorToNum'. @siju-samuel, could you help me solve this?

The code is as follows:

from tvm import relay
import torch
from pytorch_pretrained_bert import BertTokenizer, BertModel, BertForMaskedLM
import logging

logging.basicConfig(level=logging.INFO)

# Load pre-trained model tokenizer (vocabulary)
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

# Tokenized input
text = "[CLS] Who was Jim Henson ? [SEP] Jim Henson was a puppeteer [SEP]"
tokenized_text = tokenizer.tokenize(text)

# Mask a token that we will try to predict back with `BertForMaskedLM`
masked_index = 8
tokenized_text[masked_index] = '[MASK]'
assert tokenized_text == ['[CLS]', 'who', 'was', 'jim', 'henson', '?', '[SEP]', 'jim', '[MASK]', 'was', 'a', 'puppet',
                          '##eer', '[SEP]']

# Convert token to vocabulary indices
indexed_tokens = tokenizer.convert_tokens_to_ids(tokenized_text)
# Define sentence A and B indices associated to 1st and 2nd sentences (see paper)
segments_ids = [0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1]

# Convert inputs to PyTorch tensors
tokens_tensor = torch.tensor([indexed_tokens])
segments_tensors = torch.tensor([segments_ids])

print('Use BertModel to get hidden states')

# Load pre-trained model (weights)
model = BertModel.from_pretrained('bert-base-uncased')
model.eval()

# Predict hidden states features for each layer
with torch.no_grad():
    encoded_layers, _ = model(tokens_tensor, segments_tensors)
# We have a hidden states for each of the 12 layers in model bert-base-uncased
assert len(encoded_layers) == 12

print('Use BertForMaskedLM')

# Load pre-trained model (weights)
model = BertForMaskedLM.from_pretrained('bert-base-uncased')
model.eval()


# Predict all tokens
with torch.no_grad():
    predictions = model(tokens_tensor, segments_tensors)

# confirm we were able to predict 'henson'
predicted_index = torch.argmax(predictions[0, masked_index]).item()
predicted_token = tokenizer.convert_ids_to_tokens([predicted_index])[0]
assert predicted_token == 'henson'

# print(model)
scripted_model = torch.jit.trace(model, (tokens_tensor, segments_tensors)).eval()
print('Scripted_model finish')
input_1 = 'input_ids'
input_2 = 'input.2'
shape_list = [(input_1, [1, 14]), (input_2, [1, 14])]

mod, params = relay.frontend.from_pytorch(scripted_model, shape_list)

PR 5603 should fix the prim::ImplicitTensorToNum issue.

If you run into an issue with matmul, you can merge PR 5604 as well.

Thanks!

I tried to fix it with the following converter:

def _tensortonum():
    # Pass the tensor expression through unchanged instead of converting it to a number
    def _impl(inputs, input_types):
        return inputs[0]
    return _impl
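To see why registering a converter makes the first error go away, here is a minimal pure-Python sketch of the dispatch pattern, loosely modeled on how the Relay PyTorch frontend maps op names to converter factories that return an `_impl(inputs, input_types)` closure. The names `_passthrough`, `convert_ops` and `convert_map` are hypothetical stand-ins, not TVM's actual internals, and plain strings stand in for Relay expressions.

```python
def _passthrough():
    # Converter factory: returns an _impl(inputs, input_types) closure,
    # the same shape as the _tensortonum() workaround above.
    def _impl(inputs, input_types):
        return inputs[0]
    return _impl


def convert_ops(op_kinds, convert_map, value):
    """Report every unsupported op at once, then apply the converters in order."""
    missing = sorted({kind for kind in op_kinds if kind not in convert_map})
    if missing:
        raise NotImplementedError(
            "The following operators are not implemented: {}".format(missing)
        )
    for kind in op_kinds:
        # Input types are unused by these toy converters.
        value = convert_map[kind]([value], [None])
    return value


# Hypothetical op table; initially it has no entry for prim::ImplicitTensorToNum.
convert_map = {"aten::size": _passthrough()}
ops = ["aten::size", "prim::ImplicitTensorToNum"]

try:
    convert_ops(ops, convert_map, "x")
except NotImplementedError as err:
    print(err)

# Registering a converter under the op name lets the same op sequence convert.
convert_map["prim::ImplicitTensorToNum"] = _passthrough()
print(convert_ops(ops, convert_map, "x"))
```

Note that a pass-through converter only silences the missing-op check; whether downstream converters accept the passed-through expression is a separate question, which is what the next error shows.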

Then another error occurs:

TypeError: int() argument must be a string, a bytes-like object or a number, not 'Call'
'58', %58 : Scalar = prim::ImplicitTensorToNum(%57),
......
'position_ids.1', %position_ids.1 : Long(7) = aten::arange(%59, %58, %60, %61, %62, %63, %64)

The ‘aten::arange’ converter calls _create_typed_const(inputs[1], dtype), but ‘inputs[1]’ is a ‘Call’ (an unevaluated Relay expression) rather than a number, which causes the ‘TypeError’ above.
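The failure mode can be reproduced without TVM. In this sketch, `Call` is a hypothetical stand-in for a Relay call node and `create_typed_const` is a hypothetical, simplified helper that needs a concrete Python number. Passing it the unevaluated expression raises the same kind of `TypeError`, while evaluating the expression first (the role `_infer_value` plays in the workaround below) gives it a number it can use.

```python
import numpy as np


class Call:
    """Hypothetical stand-in for an unevaluated Relay expression."""

    def __init__(self, fn, *args):
        self.fn = fn
        self.args = args

    def evaluate(self):
        # Stand-in for constant folding / _infer_value: compute the concrete value.
        return self.fn(*self.args)


def create_typed_const(value, dtype):
    # Hypothetical helper: builds a typed scalar, so it needs a real number.
    return np.array(int(value), dtype=dtype)


end = Call(lambda n: n, 7)  # e.g. a sequence length feeding aten::arange

try:
    create_typed_const(end, "int64")
except TypeError as err:
    print(err)  # the message ends with "not 'Call'"

# Evaluating the expression first gives the helper the number it expects.
print(create_typed_const(end.evaluate(), "int64"))
```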

So I used the following implementation instead:

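def _tensortonum():
    # Assumes np and _infer_value are in scope, as in tvm/relay/frontend/pytorch.py
    def _impl(inputs, input_types):
        tmp = _infer_value(inputs[0], None)
        # np.str was only an alias of the built-in str and was removed in NumPy 1.24
        return np.array(tmp).astype(str)
    return _impl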

It works, but I'm not sure it is correct, because when the conversion reaches ‘aten::view’, the following error occurs:

TVMError: Check failed: ObjectTypeChecker<TObjectRef>::Check(ptr): Expect List[IntImm] but get Array

It seems something goes wrong in the call to _op.transform.reshape(data, new_shape).
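One plausible cause (an assumption, not confirmed in this thread) is the string cast in the second workaround: the folded value ends up in the shape list as a NumPy string array instead of an integer, so the shape is no longer a list of integers. The sketch below compares the two conversions with plain NumPy; `tensor_to_num_as_str`, `tensor_to_num_as_int` and `is_int_shape` are hypothetical helpers, not TVM functions, and the shape values are illustrative.

```python
import numpy as np


def tensor_to_num_as_str(value):
    # Mirrors the second workaround: cast the folded value to a string array.
    return np.array(value).astype(str)


def tensor_to_num_as_int(value):
    # Hypothetical alternative: return a plain Python int from a 0-d value.
    return int(np.array(value).item())


def is_int_shape(shape):
    # Reshape-style APIs expect every dimension to be an integer.
    return all(isinstance(dim, (int, np.integer)) for dim in shape)


folded = np.array(14, dtype="int64")  # e.g. a sequence length recovered by _infer_value

shape_str = [1, tensor_to_num_as_str(folded), 768]
shape_int = [1, tensor_to_num_as_int(folded), 768]

print(shape_str)  # the middle entry is a string array, not an integer
print(is_int_shape(shape_str), is_int_shape(shape_int))
```

If this is the cause, returning a numeric value from the converter (rather than a string) would be the thing to try first.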

I used the code you pasted in the first post and it was able to generate the complete Relay function. I'm not able to reproduce the error you mentioned. How can I reproduce the issue below?

TypeError: int() argument must be a string, a bytes-like object or a number, not 'Call'

Sorry, I didn't make that clear. The code pasted in the first post works well after applying your fix. Then I tested the ‘gpt2’ model, and it reported an error:

TypeError: int() argument must be a string, a bytes-like object or a number, not 'Call'

The test code is shown below:

from tvm import relay
import torch
from pytorch_transformers import GPT2Tokenizer, GPT2LMHeadModel

tokenizer = GPT2Tokenizer.from_pretrained('gpt2')

text = "What is the fastest car in the"
indexed_tokens = tokenizer.encode(text)
tokens_tensor = torch.tensor([indexed_tokens])

model = GPT2LMHeadModel.from_pretrained('gpt2')
model.eval()

with torch.no_grad():
    outputs = model(tokens_tensor)
    predictions = outputs[0]

predicted_index = torch.argmax(predictions[0, -1, :]).item()
predicted_text = tokenizer.decode(indexed_tokens + [predicted_index])
print(predicted_text)

scripted_model = torch.jit.trace(model, tokens_tensor).eval()

input_name = 'input_ids'
input_shapes = [(input_name, [1, 7])]
mod, params = relay.frontend.from_pytorch(scripted_model, input_shapes)

I hit the same kind of error when loading BERT: "TypeError: int() argument must be a string, a bytes-like object or a number, not 'Any'". It was resolved after I rebuilt TVM with LLVM enabled.
