[Feature Request]: Allow specification of a custom model inference method for a RunInference ModelHandler #22572

Closed
agvdndor opened this issue Aug 3, 2022 · 12 comments
Labels: core, ml, new feature, P2, python, run-inference, done & done

agvdndor (Contributor) commented Aug 3, 2022

What would you like to happen?

The current implementation of RunInference provides model handlers for PyTorch and scikit-learn models. These handlers assume that the method to call for inference is fixed:

  • PyTorch: do a forward pass by calling the __call__ method -> output = torch_model(input)
  • scikit-learn: call the model's predict method -> output = sklearn_model.predict(input)

However, in some cases we want to provide a custom method for RunInference to call.
Two examples:

  1. A number of pretrained models loaded with the Hugging Face transformers library recommend using the generate() method. From the Hugging Face docs on the T5 model:

    At inference time, it is recommended to use generate(). This method takes care of encoding the input and feeding the encoded hidden states via cross-attention layers to the decoder and auto-regressively generates the decoder output.

    
    from transformers import T5ForConditionalGeneration, T5Tokenizer

    tokenizer = T5Tokenizer.from_pretrained("t5-small")
    model = T5ForConditionalGeneration.from_pretrained("t5-small")

    input_ids = tokenizer("translate English to German: The house is wonderful.", return_tensors="pt").input_ids
    outputs = model.generate(input_ids)
    print(tokenizer.decode(outputs[0], skip_special_tokens=True))
    # prints: Das Haus ist wunderbar.
    
  2. Using OpenAI's CLIP model, which is implemented as a torch model, we might not want to execute the normal forward pass that encodes both images and text (image_embedding, text_embedding = clip_model(image, text)), but instead only compute the image embeddings with image_embedding = clip_model.encode_image(image); a short sketch follows.
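
    For illustration, a minimal sketch of that CLIP case using OpenAI's clip package (the model variant, image path, and device handling are placeholders, not part of the proposal):

    import clip
    import torch
    from PIL import Image

    device = "cuda" if torch.cuda.is_available() else "cpu"
    model, preprocess = clip.load("ViT-B/32", device=device)

    image = preprocess(Image.open("example.jpg")).unsqueeze(0).to(device)
    with torch.no_grad():
        # Only the image tower is needed here, not the full forward pass.
        image_embedding = model.encode_image(image)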

Solution: Allowing the user to specify the inference_fn when creating a ModelHandler would enable this usage.

Issue Priority

Priority: 2

Issue Component

Component: sdk-py-core

yeandy (Contributor) commented Aug 3, 2022

This also applies to scikit-learn. For example, RandomForestClassifier has predict(X), predict_proba(X), and predict_log_proba(X), as well as less common methods like apply(X).
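
For concreteness, a small illustration with synthetic data (the dataset and parameters are arbitrary):

from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier

X, y = make_classification(n_samples=100, random_state=0)
model = RandomForestClassifier(random_state=0).fit(X, y)

labels = model.predict(X[:5])               # class labels
probabilities = model.predict_proba(X[:5])  # per-class probabilities
log_probabilities = model.predict_log_proba(X[:5])
leaf_indices = model.apply(X[:5])           # leaf index per tree for each sample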

yeandy (Contributor) commented Aug 3, 2022

Parent issue: #22117

TheNeuralBit (Member) commented:

I think this would be difficult to do in a general (cross-ModelHandler) way, as each ModelHandler is responsible for invoking its model, and they currently have different ways of doing so.

sklearn calls a predict method:

predictions = model.predict(vectorized_batch)

pytorch calls the model like a callable (which then uses the forward method IIUC?):

predictions = model(**key_to_batched_tensors, **inference_args)
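
For reference, a tiny sketch of the callable convention: torch.nn.Module.__call__ dispatches to forward(), so the two calls below are equivalent (hooks aside). The model class here is just an illustrative stand-in.

import torch

class TinyModel(torch.nn.Module):
  def __init__(self):
    super().__init__()
    self.linear = torch.nn.Linear(4, 2)

  def forward(self, x):
    return self.linear(x)

model = TinyModel()
x = torch.randn(3, 4)
# model(x) goes through __call__, which invokes forward(x).
assert torch.equal(model(x), model.forward(x))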

I think the best we could do to solve the problem generally is to establish some kind of convention.

It's also worth noting that the generate method comes from Hugging Face's GenerationMixin, not from the torch.nn.Module API, which is what's in our contract.

Is a separate generation ModelHandler a better solution?

agvdndor (Contributor, Author) commented Aug 5, 2022

I could imagine three options:

  1. Stick to the current contract and assume that users will subclass the existing handlers to accommodate their model when it falls outside of the contract.
  2. Create a separate GenerationModelHandler. I'm not a fan of this approach. As @yeandy commented, there are a lot of fairly common options out there: predict_proba, apply, encode, decode, generate... So this might not scale too well and could lead to a proliferation of model handlers.
  3. Let the user pass the model_inference_fn during initialization as an optional kwarg.

Personally, I'd prefer option three. Something like this:

from transformers import DistilBertForSequenceClassification, DistilBertTokenizer, DistilBertConfig
from apache_beam.ml.inference.pytorch_inference import PytorchModelHandlerTensor

model_handler = PytorchModelHandlerTensor(
    state_dict_path="<path-to-state-dict-file>",
    model_class=DistilBertForSequenceClassification,
    model_params={"config": DistilBertConfig("<path-to-config-file>")},
    model_inference_fn=DistilBertForSequenceClassification.generate)

WDYT?

yeandy (Contributor) commented Aug 11, 2022

pytorch calls the model like a callable (which then uses the forward method IIUC?):

Correct.


And thanks @agvdndor for the detailed suggestions!

  • For GenerationModelHandler, I agree that it does not scale well.
  • A lambda like model_inference_fn could work, and the change itself shouldn't be that hard to implement. However, we need to ask ourselves: at what point are we doing too much to address these custom use cases? On the one hand, I recognize that Hugging Face is very popular, and it would be a shame to see a bunch of potential RunInference users turned away because of how difficult it is to plug a Hugging Face model into PytorchModelHandlerTensor. On the other hand, if we can capture 80% of use cases without this custom infer function, that might be good enough? If users do require a more tailored solution, then they probably should be writing their own DoFn anyway (inspired, of course, by our own implementation). @robertwb What are your thoughts on adding something like a GenerationModelHandler versus a model_inference_fn?

There are some other workarounds that users could try. Would these be sufficient solutions?

  1. Create a wrapper class that inherits from torch.nn.Module, then override its forward() method to call the model's intended inference function. (Note: this code is just an example and isn't necessarily the best or correct way to do this.)
class Tacotron2Wrapper(torch.nn.Module):
  def __init__(self, model):
    super().__init__()
    self._model = model

  def forward(self, inputs, input_lengths):
    # Delegate to the model's intended inference method instead of its forward().
    mel, _, _ = self._model.infer(inputs, input_lengths)
    return mel
  2. Inherit the ModelHandler, and change the run_inference function to call model.infer() instead of model(). This might be easier than the first solution, but it does require the user to copy the rest of the logic correctly.
  def run_inference(
      self,
      batch: Sequence[torch.Tensor],
      model: torch.nn.Module,
      inference_args: Optional[Dict[str, Any]] = None
  ) -> Iterable[PredictionResult]:
    inference_args = {} if not inference_args else inference_args

    batched_tensors = torch.stack(batch)
    batched_tensors = _convert_to_device(batched_tensors, self._device)
    predictions = model.infer(batched_tensors, **inference_args)
    return [PredictionResult(x, y) for x, y in zip(batch, predictions)]
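
For completeness, a minimal sketch of how the overridden method above could sit in a full handler subclass (the class name is illustrative, and device handling is omitted for brevity):

from typing import Any, Dict, Iterable, Optional, Sequence

import torch
from apache_beam.ml.inference.base import PredictionResult
from apache_beam.ml.inference.pytorch_inference import PytorchModelHandlerTensor


class PytorchInferMethodHandler(PytorchModelHandlerTensor):
  """Same as the base handler, except run_inference calls model.infer()."""

  def run_inference(
      self,
      batch: Sequence[torch.Tensor],
      model: torch.nn.Module,
      inference_args: Optional[Dict[str, Any]] = None
  ) -> Iterable[PredictionResult]:
    inference_args = inference_args or {}
    batched_tensors = torch.stack(batch)
    # Device handling (e.g. moving tensors to GPU) omitted for brevity.
    predictions = model.infer(batched_tensors, **inference_args)
    return [PredictionResult(x, y) for x, y in zip(batch, predictions)]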

damccorm (Contributor) commented Sep 2, 2022

Generally, my take is that we should do option 3 here and allow users to pass in a custom function. Basically:

  1. This is a reasonably common pattern
  2. Adding support shouldn't be too hard
  3. Asking users to create their own handler (or model wrapper) any time they want to use a different method doesn't scale well. Some might contribute it back to the community, most won't, and even with those who do we're incurring an extra review/maintenance burden. It also significantly raises the bar for first time users who now would need to understand the handler internals.
  4. Supporting this doesn't meaningfully make it harder for users in the simple use case (omitting this param should do nothing).

@jrmccluskey could you pick this one up when you have the bandwidth?

jrmccluskey (Contributor) commented:

Looking into this a little bit: it's doable for each handler type, but the end result is somewhat restrictive for the user. The provided function is going to have to take the same arguments in the same positions as the current inference methods. For the examples discussed this isn't a huge issue (unless HuggingFace users really want to use the 30+ optional generate() parameters) and will likely cover a large number of use cases, but we'll still have some advanced users who want more tuning and will likely turn to bespoke options.

It also looks like providing the alternate inference function will need to be done at run_inference call-time, not handler init-time, since the scikit-learn and PyTorch approaches are using functions from specific instances of their respective models. Can't specify the function until you have the model, unless I'm missing something.

damccorm (Contributor) commented Sep 7, 2022

The provided function is going to have to take the same arguments in the same position as the current inference methods. For the given examples discussed this isn't a huge issue (unless HuggingFace users really want to use the 30+ optional generate() parameters) and will likely cover a large number of use cases, but we'll still have some advanced users who will want more tuning and will likely turn to bespoke options.

I'm not 100% sure this is true; for example, I could imagine an approach where we let users pass in some sort of function like:
lambda model, batched_tensors, inference_args: model.generate(...). Regardless, I think the optional inference_args probably give users enough flexibility here, though it would be good to validate that against an existing model example.
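
Purely as an illustration of that shape (the function name and exact signature here are assumptions, not a settled API):

def generate_inference_fn(model, batched_tensors, inference_args):
  # Call the alternate method instead of model() / forward().
  return model.generate(batched_tensors, **inference_args)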

It also looks like providing the alternate inference function will need to be done at run_inference call-time, not handler init-time, since the scikit-learn and PyTorch approaches are using functions from specific instances of their respective models. Can't specify the function until you have the model, unless I'm missing something.

You could probably do something with getattr where you pass in the function name via string, though I don't love that approach since it's not very flexible w/ parameters. You could also again let them pass in a function. It's a little more work for a user, but might be worth the customizability (and for users that don't need it, their function would just be lambda model, batched_tensors, **inference_args: model.doSomething(batched_tensors, **inference_args)).
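
Roughly, the getattr idea could look like this (names are illustrative):

def resolve_inference_method(model, method_name):
  # Look up the inference method by name once the model instance exists.
  return getattr(model, method_name)

# e.g. inside run_inference:
# inference_method = resolve_inference_method(model, "generate")
# predictions = inference_method(batched_tensors, **inference_args)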

Thoughts?

damccorm assigned jrmccluskey and unassigned yeandy on Sep 9, 2022
jrmccluskey (Contributor) commented:

I've put together a brief doc discussing my perspective and preferred solution here: https://docs.google.com/document/d/1YYGsF20kminz7j9ifFdCD5WQwVl8aTeCo0cgPjbdFNU/edit?usp=sharing

PTAL

damccorm (Contributor) commented:

@jrmccluskey could you please file a follow up issue to update our notebooks to use this feature once this is released?

jrmccluskey (Contributor) commented:

Filed as #24334

damccorm (Contributor) commented:

Thanks!

damccorm added the "done & done" label on Nov 28, 2022