[FEATURE] Replace SASRec TF with PyTorch version #2111

Open
1 of 3 tasks
miguelgfierro opened this issue Jun 17, 2024 · 4 comments
Labels
enhancement New feature or request

Comments

@miguelgfierro
Collaborator

miguelgfierro commented Jun 17, 2024

Description

The SASRec tests are currently disabled: https://github.com/recommenders-team/recommenders/blob/main/tests/ci/azureml_tests/test_groups.py#L410
We could replace the TensorFlow implementation with the UniRec one: https://github.com/microsoft/UniRec/blob/main/unirec/model/sequential/sasrec.py

Expected behavior with the suggested feature

Branch: https://github.com/recommenders-team/recommenders/tree/miguel/sasrec_unirec

Tasks:

  • Create unit tests for the classes
  • Make sure the original script runs
  • Create a functional test that trains on MovieLens using only the minimal parts of the code

Other Comments

@miguelgfierro miguelgfierro added the enhancement New feature or request label Jun 17, 2024
@miguelgfierro
Collaborator Author

pytest -s tests/unit/recommenders/models/test_unirec_model.py::test_sasrec_train --disable-warnings

@miguelgfierro
Collaborator Author

miguelgfierro commented Jul 5, 2024

import cvxpy as cp
E   ModuleNotFoundError: No module named 'cvxpy'

solved with pip install cvxpy

another error:

FAILED tests/unit/recommenders/models/test_unirec_model.py::test_sasrec_train - ModuleNotFoundError: No module named 'feather'

solved with pip install feather-format
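
Both failures above are missing optional dependencies, so it may help to check for them up front instead of hitting them one at a time. The following is a minimal sketch, not part of the branch: the helper name `missing_modules` is hypothetical, and the module list only contains the two import names seen in the errors above.

```python
from importlib.util import find_spec

# Import names taken from the two ModuleNotFoundError messages above.
# Note that "feather" is provided by the "feather-format" package on PyPI.
REQUIRED_MODULES = ["cvxpy", "feather"]


def missing_modules(names):
    """Return the import names that cannot be found in this environment."""
    return [name for name in names if find_spec(name) is None]


if __name__ == "__main__":
    missing = missing_modules(REQUIRED_MODULES)
    if missing:
        print("Missing modules:", ", ".join(missing))
    else:
        print("All listed modules are importable.")
```

Running this before pytest lists everything that still needs installing in one pass.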

@miguelgfierro
Collaborator Author

 @pytest.mark.gpu
    def test_sasrec_train(base_config, unirec_config_path):
        # config = copy.deepcopy(base_config)
        # yaml_file = os.path.join(unirec_config_path, "model", "SASRec.yaml")
        # config.update(load_yaml(yaml_file))
    
        # model = SASRec(config)
        import copy
        import datetime
        from recommenders.models.unirec.main import main
    
        GLOBAL_CONF = {
            # "config_dir": f"{os.path.join(unirec_config_path, 'unirec', 'config')}",
            "config_dir": unirec_config_path,
            "exp_name": "pytest",
            "checkpoint_dir": f'{datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")}',
            "model": "",
            "dataloader": "SeqRecDataset",
            "dataset": "",
            "dataset_path": os.path.join(unirec_config_path, "tests/.temp/data"),
            "output_path": "",
            "learning_rate": 0.001,
            "dropout_prob": 0.0,
            "embedding_size": 32,
            "hidden_size": 32,
            "use_pre_item_emb": 0,
            "loss_type": "bce",
            "max_seq_len": 10,
            "has_user_bias": 1,
            "has_item_bias": 1,
            "epochs": 1,
            "early_stop": -1,
            "batch_size": 512,
            "n_sample_neg_train": 9,
            "valid_protocol": "one_vs_all",
            "test_protocol": "one_vs_all",
            "grad_clip_value": 0.1,
            "weight_decay": 1e-6,
            "history_mask_mode": "autoagressive",
            "user_history_filename": "user_history",
            "metrics": "['hit@5;10', 'ndcg@5;10']",
            "key_metric": "ndcg@5",
            "num_workers": 4,
            "num_workers_test": 0,
            "verbose": 2,
            "neg_by_pop_alpha": 0.0,
            "conv_size": 10,  # for ConvFormer-series
        }
        config = copy.deepcopy(GLOBAL_CONF)
        config["task"] = "train"
        config["dataset_path"] = os.path.join(config["dataset_path"], "ml-100k")
        config["dataset"] = "ml-100k"
        config["model"] = "SASRec"
        config["output_path"] = os.path.join(unirec_config_path, f"tests/.temp/output/")
>       result = main.run(config)

tests/unit/recommenders/models/test_unirec_model.py:146: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
recommenders/models/unirec/main/main.py:676: in run
    res = main(config, accelerator)
recommenders/models/unirec/main/main.py:357: in main
    user2history, user2history_time = get_user_history(
recommenders/models/unirec/main/main.py:137: in get_user_history
    user2history, user2history_time = general.load_user_history(
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _

file_path = '/home/u/MS/recommenders/recommenders/models/unirec/config/tests/.temp/data/ml-100k', file_name = 'user_history', n_users = 940, format = 'user-item_seq', time_seq = 0

    def load_user_history(
        file_path, file_name, n_users=None, format="user-item", time_seq=0
    ):
        if os.path.exists(os.path.join(file_path, file_name + ".ftr")):
            df = pd.read_feather(os.path.join(file_path, file_name + ".ftr"))
        elif os.path.exists(os.path.join(file_path, file_name + ".pkl")):
            df = load_pkl_obj(os.path.join(file_path, file_name + ".pkl"))
        else:
>           raise NotImplementedError(
                "Unsupported user history file type: {0}".format(file_name)
            )
E           NotImplementedError: Unsupported user history file type: user_history

recommenders/models/unirec/utils/general.py:134: NotImplementedError
----------------------------------------------------------------------------------------------------- Captured log call -----------------------------------------------------------------------------------------------------
INFO     SASRec-pytest:logger.py:61 config={'gpu_id': 0, 'use_gpu': True, 'seed': 2022, 'state': 'INFO', 'verbose': 2, 'saved': True, 'use_tensorboard': False, 'use_wandb': False, 'init_method': 'normal', 'init_std': 0.02, 'init_mean': 0.0, 'scheduler': 'reduce', 'scheduler_factor': 0.1, 'time_seq': 0, 'seq_last': False, 'has_user_emb': False, 'has_user_bias': 1, 'has_item_bias': 1, 'use_features': False, 'use_text_emb': False, 'use_position_emb': True, 'load_pretrained_model': False, 'embedding_size': 32, 'hidden_size': 32, 'inner_size': 512, 'dropout_prob': 0.0, 'epochs': 1, 'batch_size': 512, 'learning_rate': 0.001, 'optimizer': 'adam', 'eval_step': 1, 'early_stop': -1, 'clip_grad_norm': None, 'weight_decay': 1e-06, 'num_workers': 4, 'persistent_workers': False, 'pin_memory': False, 'shuffle_train': False, 'use_pre_item_emb': 0, 'loss_type': 'bce', 'ccl_w': 150, 'ccl_m': 0.4, 'distance_type': 'dot', 'metrics': "['hit@5;10', 'ndcg@5;10']", 'key_metric': 'ndcg@5', 'test_protocol': 'one_vs_all', 'valid_protocol': 'one_vs_all', 'test_batch_size': 100, 'model': 'SASRec', 'dataloader': 'SeqRecDataset', 'max_seq_len': 10, 'history_mask_mode': 'autoagressive', 'tau': 1.0, 'enable_morec': 0, 'morec_objectives': ['fairness', 'alignment', 'revenue'], 'morec_objective_controller': 'PID', 'morec_ngroup': [10, 10, -1], 'morec_alpha': 0.1, 'morec_lambda': 0.2, 'morec_expect_loss': 0.2, 'morec_beta_min': 0.6, 'morec_beta_max': 1.3, 'morec_K_p': 0.01, 'morec_K_i': 0.001, 'morec_objective_weights': '[0.3,0.3,0.4]', 'n_layers': 2, 'n_heads': 16, 'hidden_dropout_prob': 0.5, 'attn_dropout_prob': 0.5, 'hidden_act': 'swish', 'layer_norm_eps': '1e-10', 'group_size': -1, 'n_items': 1017, 'n_neg_test_from_sampling': 0, 'n_neg_train_from_sampling': 0, 'n_neg_valid_from_sampling': 0, 'n_users': 940, 'test_file_format': 'user-item', 'train_file_format': 'user-item', 'user_history_file_format': 'user-item_seq', 'valid_file_format': 'user-item', 'base_model': 'GRU', 'freeze': 0, 'train_type': 
'Base', 'config_dir': PosixPath('/home/u/MS/recommenders/recommenders/models/unirec/config'), 'exp_name': 'SASRec-pytest', 'checkpoint_dir': '2024-07-05_12-25-03', 'dataset': 'ml-100k', 'dataset_path': '/home/u/MS/recommenders/recommenders/models/unirec/config/tests/.temp/data/ml-100k', 'output_path': '/home/u/MS/recommenders/recommenders/models/unirec/config/tests/.temp/output/', 'n_sample_neg_train': 9, 'grad_clip_value': 0.1, 'user_history_filename': 'user_history', 'num_workers_test': 0, 'neg_by_pop_alpha': 0.0, 'conv_size': 10, 'task': 'train', 'cmd_args': {'base_model': 'GRU', 'freeze': 0, 'train_type': 'Base', 'config_dir': PosixPath('/home/u/MS/recommenders/recommenders/models/unirec/config'), 'exp_name': 'SASRec-pytest', 'checkpoint_dir': '2024-07-05_12-25-03', 'model': 'SASRec', 'dataloader': 'SeqRecDataset', 'dataset': 'ml-100k', 'dataset_path': '/home/u/MS/recommenders/recommenders/models/unirec/config/tests/.temp/data/ml-100k', 'output_path': '/home/u/MS/recommenders/recommenders/models/unirec/config/tests/.temp/output/', 'learning_rate': 0.001, 'dropout_prob': 0.0, 'embedding_size': 32, 'hidden_size': 32, 'use_pre_item_emb': 0, 'loss_type': 'bce', 'max_seq_len': 10, 'has_user_bias': 1, 'has_item_bias': 1, 'epochs': 1, 'early_stop': -1, 'batch_size': 512, 'n_sample_neg_train': 9, 'valid_protocol': 'one_vs_all', 'test_protocol': 'one_vs_all', 'grad_clip_value': 0.1, 'weight_decay': 1e-06, 'history_mask_mode': 'autoagressive', 'user_history_filename': 'user_history', 'metrics': "['hit@5;10', 'ndcg@5;10']", 'key_metric': 'ndcg@5', 'num_workers': 4, 'num_workers_test': 0, 'verbose': 2, 'neg_by_pop_alpha': 0.0, 'conv_size': 10, 'task': 'train', 'logger_time_str': '2024-07-05_122503', 'logger_rand': 91}, 'device': device(type='cpu'), 'logger_time_str': '2024-07-05_122503', 'logger_rand': 91}
INFO     SASRec-pytest:main.py:136 Loading user history from user_history ...
================================================================================================== short test summary info ==================================================================================================
FAILED tests/unit/recommenders/models/test_unirec_model.py::test_sasrec_train - NotImplementedError: Unsupported user history file type: user_history
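
Although the message says "Unsupported user history file type", the code shown in the traceback raises it whenever neither user_history.ftr nor user_history.pkl exists under file_path, so a missing file or a wrong directory looks like the more probable cause. The dataset_path in the test is built from unirec_config_path, which resolves to .../unirec/config/tests/.temp/data/ml-100k in the log, and that may not be where the data was prepared. The sketch below mirrors only the existence check from the traceback to show which paths are tried; the helper names `find_user_history` and `describe_missing` are hypothetical and not part of UniRec.

```python
import os

# Extensions checked by load_user_history in the traceback above, in order.
SUPPORTED_EXTENSIONS = (".ftr", ".pkl")


def find_user_history(file_path, file_name):
    """Return the first existing history file, or None if neither exists."""
    for ext in SUPPORTED_EXTENSIONS:
        candidate = os.path.join(file_path, file_name + ext)
        if os.path.exists(candidate):
            return candidate
    return None


def describe_missing(file_path, file_name):
    """Build a message listing every path that was tried."""
    tried = [
        os.path.join(file_path, file_name + ext) for ext in SUPPORTED_EXTENSIONS
    ]
    return "No user history file found. Tried: " + ", ".join(tried)
```

Calling `find_user_history(config["dataset_path"], "user_history")` before `main.run(config)` would tell us whether the data files need to be generated or the path needs to point somewhere else.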

@miguelgfierro
Collaborator Author

miguelgfierro commented Aug 26, 2024

Work so far: staging...miguel/sasrec_unirec

The next step is to create a unit test called test_sasrec_train, which should train SASRec with the minimum set of options on a dummy dataset. We should first make sure that the code with result = main.run(config) runs, and then replace it with the minimum set of functions.

The steps should follow the structure of https://github.com/recommenders-team/recommenders/blob/main/examples/00_quick_start/sar_movielens.ipynb:

  • Load the data
  • Split it into train and test iterators
  • Instantiate the model
  • Train the model

If we want, we can also:

  • Evaluate
  • Get metrics

After this, we will create a notebook explaining an end-to-end case with a real dataset, and we will replace the TF notebook.
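
As a starting point for the first two steps on a dummy dataset, here is a minimal sketch that needs nothing beyond the standard library. Everything in it is an assumption for illustration: the function names are hypothetical, item ids start at 1 on the assumption that 0 is reserved for padding, and the leave-one-out split is a common choice for sequential recommenders rather than something this issue or UniRec prescribes. Model instantiation and training are left out because they depend on the UniRec API.

```python
import random


def make_dummy_sequences(n_users=5, n_items=20, min_len=3, max_len=8, seed=0):
    """Build a tiny synthetic dataset: user id -> chronological list of item ids.

    Item ids start at 1 (assumption: 0 is kept free for padding).
    """
    rng = random.Random(seed)
    return {
        user: [rng.randint(1, n_items) for _ in range(rng.randint(min_len, max_len))]
        for user in range(1, n_users + 1)
    }


def leave_one_out_split(user2seq):
    """Hold out each user's last item for testing and train on the rest."""
    train, test = {}, {}
    for user, seq in user2seq.items():
        if len(seq) < 2:
            # Too short to split: keep the whole sequence for training.
            train[user] = list(seq)
            continue
        train[user] = seq[:-1]
        test[user] = seq[-1]
    return train, test


if __name__ == "__main__":
    data = make_dummy_sequences()
    train, test = leave_one_out_split(data)
    print(f"{len(train)} training sequences, {len(test)} held-out items")
```

A fixed seed keeps the dummy data deterministic, which makes the eventual unit test reproducible.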
