This project fine-tunes a BART model for multitask learning on Amazon reviews: review summarization and rating prediction. The model generates a concise summary of each review and predicts its rating from the review text. Training uses Hugging Face’s Accelerate for distributed training, which enables multi-GPU support and mixed-precision training.
Get Dataset: Review the dataset, then update your data configuration file (config/data_info.yaml) to match.
Configuration Files: Create and update your configuration files for data and training settings. Example files can be found in the config/ directory.
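As a rough illustration of how the YAML configuration files might be read and sanity-checked before training, here is a minimal sketch. The key names in `REQUIRED_TRAINING_KEYS` are hypothetical and chosen only for illustration; the real keys are whatever the example files in config/ define. The helper names (`load_yaml`, `missing_keys`, `load_training_config`) are likewise assumptions, not part of this project.

```python
from pathlib import Path


def load_yaml(path):
    """Read one YAML file into a dict (requires the pyyaml package)."""
    import yaml  # imported lazily so missing_keys() works without pyyaml

    with Path(path).open("r", encoding="utf-8") as fh:
        # safe_load returns None for an empty file, so fall back to {}
        return yaml.safe_load(fh) or {}


def missing_keys(config, required):
    """Return the required keys that are absent from a loaded config."""
    return [key for key in required if key not in config]


# Hypothetical keys, for illustration only; check the example files in config/.
REQUIRED_TRAINING_KEYS = ["batch_size", "learning_rate", "num_epochs"]


def load_training_config(path="config/training_info.yaml"):
    """Load the training config and fail early if an expected key is missing."""
    config = load_yaml(path)
    missing = missing_keys(config, REQUIRED_TRAINING_KEYS)
    if missing:
        raise KeyError(f"{path} is missing keys: {missing}")
    return config
```

Failing early on a missing key is a design choice: it surfaces a misconfigured file before any model is loaded, rather than partway through training.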
Run Training Script:
python train.py
This script will use the configurations specified in your YAML files to train the BART model and save the best-performing model.
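The README does not state how "best-performing" is measured. A minimal sketch of one common approach, assuming the criterion is the lowest validation loss per epoch, could look like this. The `BestModelTracker` class and the loss values are made up for illustration and are not taken from train.py.

```python
class BestModelTracker:
    """Remember the best validation loss seen so far (lower is better)."""

    def __init__(self):
        self.best_loss = float("inf")
        self.best_epoch = None

    def update(self, epoch, val_loss):
        """Return True when this epoch improves on the best loss so far."""
        if val_loss < self.best_loss:
            self.best_loss = val_loss
            self.best_epoch = epoch
            return True
        return False


tracker = BestModelTracker()
# Made-up validation losses standing in for real per-epoch evaluation results.
for epoch, val_loss in enumerate([0.92, 0.71, 0.78, 0.65]):
    if tracker.update(epoch, val_loss):
        # In a real run, this is where the checkpoint would be written,
        # for example with torch.save(model.state_dict(), path).
        print(f"epoch {epoch}: new best validation loss {val_loss:.2f}")
```

Only epochs that improve on the previous best trigger a save, so the file on disk always holds the best model seen so far.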
Run the App:
python app.py
This command starts the Flask server, which includes the Gradio interface for interacting with your model.
Open http://localhost:5000 in your browser to view the Gradio interface embedded within the Flask application.
You can also interact with the model via the Flask API. Here is an example of how to use cURL to get a summary and rating:
curl -X POST http://localhost:5000/predict -H "Content-Type: application/json" -d '{"review_text": "Your review text here"}'
Response Example:
{
  "summary": "Generated summary of the review text.",
  "rating": 4
}
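The same request can be made from Python. The sketch below mirrors the cURL command using only the standard library; the URL, the `review_text` field, and the `summary` and `rating` response keys come from the example above, while the function names `build_request` and `predict` and the timeout value are assumptions made for illustration.

```python
import json
import urllib.request

API_URL = "http://localhost:5000/predict"  # endpoint from the cURL example


def build_request(review_text, url=API_URL):
    """Build the same JSON POST request that the cURL command sends."""
    payload = json.dumps({"review_text": review_text}).encode("utf-8")
    return urllib.request.Request(
        url,
        data=payload,
        headers={"Content-Type": "application/json"},
        method="POST",
    )


def predict(review_text, url=API_URL, timeout=30.0):
    """Send the review to the running Flask server and return the parsed JSON."""
    request = build_request(review_text, url)
    with urllib.request.urlopen(request, timeout=timeout) as response:
        return json.loads(response.read().decode("utf-8"))
```

With the server from `python app.py` running, calling `predict("Your review text here")` should return a dict with `summary` and `rating` keys, matching the response example.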
train.py: Script to train the BART model with specified configurations.
app.py: Flask application integrating the Gradio interface.
model.py: Contains the BART model definition for review summarization and rating.
templates/index.html: HTML template for embedding the Gradio interface.
config/data_info.yaml: Configuration file for data-related settings.
config/training_info.yaml: Configuration file for training-related settings.

transformers: For model and tokenizer.
torch: PyTorch library for deep learning.
flask: Web framework for building the API.
gradio: Interface for interactive model deployment.
pyyaml: For reading configuration files.

This project is licensed under the MIT License. See the LICENSE file for details.