Java (Kotlin) AI framework DJL

heeheeai 2024-06-24 15:31:01

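DJL (Deep Java Library) is an open-source deep learning framework released by AWS. DJL supports multiple deep learning backends (engines), including but not limited to:

MXNet: an open-source deep learning framework backed by the Apache Software Foundation.

PyTorch: a widely used open-source machine learning library developed by Facebook's AI research team.

TensorFlow: another popular open-source machine learning framework, developed by Google.

DJL integrates closely with the Java ecosystem and can be used together with Java frameworks such as Spring Boot and Quarkus.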

Maven

<!-- djl -->
<dependency>
    <groupId>ai.djl</groupId>
    <artifactId>api</artifactId>
    <version>0.28.0</version>
</dependency>
<dependency>
    <groupId>ai.djl.pytorch</groupId>
    <artifactId>pytorch-engine</artifactId>
    <version>0.28.0</version>
</dependency>
<dependency>
    <groupId>ai.djl.pytorch</groupId>
    <artifactId>pytorch-model-zoo</artifactId>
    <version>0.28.0</version>
</dependency>
<dependency>
    <groupId>ai.djl</groupId>
    <artifactId>basicdataset</artifactId>
    <version>0.28.0</version>
</dependency>
<dependency>
    <groupId>ai.djl</groupId>
    <artifactId>model-zoo</artifactId>
    <version>0.28.0</version>
</dependency>
<!-- /djl -->

Java DJL architecture diagram

┌──────────────────────────────┐
│           ModelZoo           │
├──────────────────────────────┤
│            Model             │
└───────────────┬──────────────┘
                │
      ┌─────────▼─────────┐
      │      Engine       │
      └───────┬─┬─────────┘
              │ │
      ┌───────▼─▼─────────┐
      │     NDManager     │
      └───────┬─┬─────────┘
              │ │
      ┌───────▼─▼─────────┐
      │      Dataset      │
      └─────────┬─────────┘
                │
      ┌─────────▼─────────┐
      │ Trainer/Predictor │
      └───────────────────┘

Main components in detail

1. ModelZoo and Model

ModelZoo: provides a wide range of pre-trained models.

What ModelZoo does
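Model discovery and download: ModelZoo provides a mechanism for discovering and downloading pre-trained models from multiple sources (model providers, online repositories, and so on), for example AWS S3, Hugging Face, or TensorFlow Hub.

Model loading: ModelZoo offers convenient methods for loading models, so you can load whichever kind of model you need (image classification, object detection, natural language processing, and more). When loading a model you can specify its name, version, and parameter configuration.

Model management: ModelZoo helps manage the models that have been downloaded and loaded. Downloaded files are kept in a local cache (by default under ~/.djl.ai, configurable through the DJL_CACHE_DIR environment variable), which avoids repeated downloads and wasted storage, and ModelZoo.listModels() lists the models available in the zoo.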
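The caching behavior described under model management can be pictured with a small plain-Java sketch. It is a conceptual illustration only, not DJL's implementation: `LocalModelCache`, the `Downloader` callback, and the directory layout are hypothetical. The point is that a model is fetched the first time it is requested and served from local storage afterwards.

```java
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Path;

/** Hypothetical callback that fetches a model and writes it to the target file. */
interface Downloader {
    void download(String modelId, Path target) throws IOException;
}

/**
 * Conceptual sketch of a local model cache: download on first request, reuse afterwards.
 * The layout root/group/artifact/version/model.bin is hypothetical, not DJL's real layout.
 */
class LocalModelCache {
    private final Path root;
    private final Downloader downloader;

    LocalModelCache(Path root, Downloader downloader) {
        this.root = root;
        this.downloader = downloader;
    }

    /** Returns the cached file, calling the downloader only when the file is missing. */
    Path fetch(String groupId, String artifactId, String version) throws IOException {
        Path target = root.resolve(groupId).resolve(artifactId).resolve(version).resolve("model.bin");
        if (Files.notExists(target)) {
            Files.createDirectories(target.getParent());
            downloader.download(groupId + "/" + artifactId + "/" + version, target);
        }
        return target;
    }

    public static void main(String[] args) throws IOException {
        Path cacheRoot = Files.createTempDirectory("model-cache");
        int[] downloads = {0};
        LocalModelCache cache = new LocalModelCache(cacheRoot, (id, file) -> {
            downloads[0]++;
            Files.write(file, id.getBytes(StandardCharsets.UTF_8));
        });
        cache.fetch("example.group", "resnet", "0.0.1"); // first request: downloads
        cache.fetch("example.group", "resnet", "0.0.1"); // second request: served from the cache
        System.out.println("downloads: " + downloads[0]); // prints "downloads: 1"
    }
}
```

A different version string maps to a different path, so requesting a new version triggers a fresh download while older versions stay cached.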

Example

import ai.djl.Application
import ai.djl.ModelException
import ai.djl.modality.Classifications
import ai.djl.modality.cv.Image
import ai.djl.repository.zoo.Criteria
import ai.djl.repository.zoo.ModelZoo
import java.io.IOException

object ModelZooExample {
    @Throws(IOException::class, ModelException::class)
    @JvmStatic
    fun main(args: Array<String>) {
        // Describe the model we want
        val criteria: Criteria<Image, Classifications> = Criteria.builder()
            .setTypes(Image::class.java, Classifications::class.java) // input and output types
            .optApplication(Application.CV.IMAGE_CLASSIFICATION) // use case: image classification
            .optFilter("layers", "50") // filter on model metadata: ResNet-50
            .build()

        // Load the model from the ModelZoo; the ZooModel is AutoCloseable, so release it with use
        ModelZoo.loadModel(criteria).use { model ->
            println("Loaded model: ${model.name}")
            // Run inference with the model, e.g. model.newPredictor().use { predictor -> ... }
        }
    }
}

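ModelZoo classes and interfaces

ModelZoo: the core class; it provides model downloading and loading.

Criteria: defines the conditions and filters for loading a model, such as the application, the input/output types, and so on.

ModelLoader: performs the actual downloading and loading of a model.

Model: the interface that represents a deep learning model; it covers loading and saving the model and creating Predictor and Trainer objects to run it.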

ai.djl.repository.zoo.ModelZoo

Key Methods:
static <I, O> ZooModel<I, O> loadModel(Criteria<I, O> criteria): Loads the model that matches the given criteria.
static Map<Application, List<Artifact>> listModels(): Lists the models available in the model zoos on the classpath, grouped by application.
ModelLoader getModelLoader(String name): Returns the loader for a named model in this zoo.

ai.djl.repository.Artifact Class

An Artifact describes one model in a zoo: its name, version, and metadata properties (such as layers or dataset, which Criteria filters are matched against). For a model that has already been loaded, input and output information is available from Model.describeInput() and Model.describeOutput().

Key Methods:
String getName(): Returns the name of the model.
String getVersion(): Returns the version of the model.
Map<String, String> getProperties(): Returns the metadata properties used for filtering.

ai.djl.repository.MRL Class

MRL (Model Resource Locator) identifies a model in a repository by its application, group ID, and artifact ID; the version comes from the matching Artifact.

Key Methods:
String getGroupId(): Gets the group ID of the model.
String getArtifactId(): Gets the artifact ID of the model.
Application getApplication(): Gets the application of the model.

ai.djl.Application Class

Application defines the types of tasks supported by the model zoo. The values are constants grouped under nested interfaces such as Application.CV and Application.NLP, for example IMAGE_CLASSIFICATION and OBJECT_DETECTION.

Key Values:
CV.IMAGE_CLASSIFICATION
CV.OBJECT_DETECTION
NLP.TEXT_CLASSIFICATION

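ai.djl.repository.zoo.Criteria Class

Criteria describes the model to be loaded (application, input/output types, engine, filters). Criteria objects are created with a builder.

Key Methods:
static Builder<?, ?> builder(): Creates a new builder instance.
Builder<P, Q> setTypes(Class<P> inputClass, Class<Q> outputClass): Sets the input and output types.
Builder<I, O> optApplication(Application application): Sets the application type.
Builder<I, O> optEngine(String engine): Specifies the engine to use (e.g., "MXNet", "PyTorch").
Builder<I, O> optFilter(String key, String value): Adds a filter on the model metadata.
Criteria<I, O> build(): Builds the Criteria object.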
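How a set of criteria narrows a zoo down to matching models can be sketched without DJL. `ModelEntry` and `SimpleCriteria` below are hypothetical classes that only mirror the idea behind `Criteria.builder()`, `optApplication`, `optEngine`, and `optFilter`: an entry matches when the application and engine agree (where set) and every filter key/value pair appears in its properties. DJL's real matching logic may differ in its details.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

/** Hypothetical zoo entry: application, engine, and metadata properties. */
class ModelEntry {
    final String name;
    final String application;
    final String engine;
    final Map<String, String> properties;

    ModelEntry(String name, String application, String engine, Map<String, String> properties) {
        this.name = name;
        this.application = application;
        this.engine = engine;
        this.properties = properties;
    }
}

/** Hypothetical criteria built with a fluent builder, in the spirit of DJL's Criteria. */
class SimpleCriteria {
    private final String application; // null means "any application"
    private final String engine;      // null means "any engine"
    private final Map<String, String> filters;

    private SimpleCriteria(Builder b) {
        this.application = b.application;
        this.engine = b.engine;
        this.filters = Collections.unmodifiableMap(new HashMap<>(b.filters));
    }

    static Builder builder() {
        return new Builder();
    }

    /** Keeps only the entries that satisfy every condition that was set. */
    List<ModelEntry> select(List<ModelEntry> zoo) {
        List<ModelEntry> result = new ArrayList<>();
        for (ModelEntry e : zoo) {
            boolean ok = (application == null || application.equals(e.application))
                    && (engine == null || engine.equals(e.engine));
            for (Map.Entry<String, String> f : filters.entrySet()) {
                ok = ok && f.getValue().equals(e.properties.get(f.getKey()));
            }
            if (ok) {
                result.add(e);
            }
        }
        return result;
    }

    static class Builder {
        private String application;
        private String engine;
        private final Map<String, String> filters = new HashMap<>();

        Builder optApplication(String application) { this.application = application; return this; }

        Builder optEngine(String engine) { this.engine = engine; return this; }

        Builder optFilter(String key, String value) { filters.put(key, value); return this; }

        SimpleCriteria build() { return new SimpleCriteria(this); }
    }
}
```

For example, `SimpleCriteria.builder().optApplication("cv/image_classification").optFilter("layers", "50").build().select(zoo)` keeps only entries whose properties contain `layers=50`; the application string format is made up for this sketch.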
Example

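import ai.djl.Model
import ai.djl.ModelException
import ai.djl.modality.Classifications
import ai.djl.modality.cv.Image
import ai.djl.modality.cv.ImageFactory
import ai.djl.modality.cv.util.NDImageUtils
import ai.djl.ndarray.NDList
import ai.djl.translate.TranslateException
import ai.djl.translate.Translator
import ai.djl.translate.TranslatorContext
import java.io.IOException
import java.nio.file.Paths

object DjlExample {
    @JvmStatic
    fun main(args: Array<String>) {
        // Model directory and name: load() looks for a model file named after modelName in modelDir
        val modelDir = Paths.get("models")
        val modelName = "resnet18"
        try {
            Model.newInstance(modelName).use { model ->
                // Load the model
                model.load(modelDir)
                // Load the input image
                val img = ImageFactory.getInstance().fromFile(Paths.get("path/to/image.jpg"))
                // Create a predictor (closed by use) and run inference
                model.newPredictor(MyTranslator()).use { predictor ->
                    val result = predictor.predict(img)
                    println(result)
                }
            }
        } catch (e: IOException) {
            e.printStackTrace()
        } catch (e: ModelException) {
            e.printStackTrace()
        } catch (e: TranslateException) {
            e.printStackTrace()
        }
    }

    // Custom Translator: Image -> NDArray before inference, NDArray -> Classifications after
    private class MyTranslator : Translator<Image, Classifications> {
        override fun processInput(ctx: TranslatorContext, input: Image): NDList {
            // HWC uint8 image -> 224x224 -> CHW float in [0, 1] -> ImageNet mean/std normalization
            var array = input.toNDArray(ctx.ndManager)
            array = NDImageUtils.resize(array, 224, 224)
            array = NDImageUtils.toTensor(array)
            array = NDImageUtils.normalize(
                array, floatArrayOf(0.485f, 0.456f, 0.406f), floatArrayOf(0.229f, 0.224f, 0.225f)
            )
            return NDList(array)
        }

        override fun processOutput(ctx: TranslatorContext, list: NDList): Classifications {
            // The default STACK batchifier removes the batch dimension, so softmax runs over axis 0
            val probabilities = list.singletonOrThrow().softmax(0)
            // Placeholder labels: an ImageNet ResNet has 1000 classes; replace with the real synset
            val labels: List<String> = List(1000) { "class_$it" }
            return Classifications(labels, probabilities)
        }
    }
}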
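In the translator above, `processOutput` converts raw scores (logits) into probabilities with softmax before wrapping them in `Classifications`. The arithmetic fits in a few lines of plain Java; `SoftmaxDemo` and its labels are hypothetical, and subtracting the maximum logit before exponentiating is a standard numerical-stability trick that leaves the result unchanged.

```java
import java.util.Arrays;
import java.util.List;

/** Plain-Java sketch of softmax followed by picking the most likely label. */
final class SoftmaxDemo {
    private SoftmaxDemo() {}

    /** softmax(x)_i = exp(x_i - max) / sum_j exp(x_j - max). */
    static double[] softmax(double[] logits) {
        double max = Double.NEGATIVE_INFINITY;
        for (double v : logits) {
            max = Math.max(max, v);
        }
        double[] out = new double[logits.length];
        double sum = 0.0;
        for (int i = 0; i < logits.length; i++) {
            out[i] = Math.exp(logits[i] - max);
            sum += out[i];
        }
        for (int i = 0; i < out.length; i++) {
            out[i] /= sum;
        }
        return out;
    }

    /** Index of the largest probability (ties go to the first index). */
    static int argmax(double[] probs) {
        int best = 0;
        for (int i = 1; i < probs.length; i++) {
            if (probs[i] > probs[best]) {
                best = i;
            }
        }
        return best;
    }

    public static void main(String[] args) {
        List<String> labels = Arrays.asList("cat", "dog", "bird"); // hypothetical labels
        double[] probs = softmax(new double[] {2.0, 1.0, 0.1});
        System.out.println(labels.get(argmax(probs)) + " " + Arrays.toString(probs));
    }
}
```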

2. Dataset

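Common dataset types:

RandomAccessDataset: the basic dataset type for data that supports random access, such as arrays or lists. It is an abstract class that supports batching and splitting the data into subsets, and it suits most supervised learning tasks.

IterableDataset: for data that cannot be accessed randomly, such as streams or data generated on the fly. It delivers data through an iterator, which suits data sources that are generated or processed dynamically.

RecordDataset: a dataset format based on record files, commonly used for large-scale data processing. It loads and processes data records efficiently and suits distributed training and large datasets.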
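Random access and batching can be illustrated with a hypothetical in-memory dataset. `ArrayBackedDataset` below is not a DJL class: `get(index)` and `size()` capture what random access means, and `batches(batchSize, shuffle, seed)` groups (optionally shuffled) indices into mini-batches, roughly the idea behind a call such as `setSampling(32, true)`. This sketch keeps the final, smaller batch.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Random;

/** Hypothetical random-access dataset backed by an in-memory list. */
class ArrayBackedDataset<T> {
    private final List<T> records;

    ArrayBackedDataset(List<T> records) {
        this.records = records;
    }

    /** Random access: any record can be fetched directly by its index. */
    T get(int index) {
        return records.get(index);
    }

    int size() {
        return records.size();
    }

    /** Splits (optionally shuffled) indices into batches; the last batch may be smaller. */
    List<List<T>> batches(int batchSize, boolean shuffle, long seed) {
        List<Integer> indices = new ArrayList<>();
        for (int i = 0; i < size(); i++) {
            indices.add(i);
        }
        if (shuffle) {
            Collections.shuffle(indices, new Random(seed));
        }
        List<List<T>> result = new ArrayList<>();
        for (int start = 0; start < indices.size(); start += batchSize) {
            List<T> batch = new ArrayList<>();
            for (int i = start; i < Math.min(start + batchSize, indices.size()); i++) {
                batch.add(get(indices.get(i)));
            }
            result.add(batch);
        }
        return result;
    }
}
```

With 70 records and a batch size of 32, this yields batches of 32, 32, and 6 records; an iterable (streaming) dataset could not shuffle this way, because it has no `get(index)` to jump around with.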

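The DJL dataset components provide:

Data loading and preprocessing: loading data from many sources, such as local files, remote servers, and databases, plus preprocessing such as normalization, data augmentation, and feature extraction.

Batching: splitting the data into mini-batches, which is what makes training on large datasets practical, with a flexible batching strategy that can be customized.

Transformations: image, text, and numeric transformations; several transforms can be chained into a data processing pipeline.

Data loading during training (the "DataLoader" role): in DJL, Dataset.getData(manager) (or Trainer.iterateDataset(dataset)) returns an Iterable<Batch> that packs the dataset into batches and supplies them on demand during training. Passing an ExecutorService to Dataset.getData(manager, executorService) enables multi-threaded loading for better throughput.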
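Chaining transformations into a pipeline comes down to applying functions in order. The sketch uses `java.util.function.UnaryOperator` instead of DJL's own transform types, and `TransformPipeline` is a hypothetical name; scaling pixels to [0, 1] and then normalizing with a mean and standard deviation mirrors a typical image preprocessing chain, with illustrative constants.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.function.UnaryOperator;

/** Hypothetical pipeline: applies each transform in the order it was added. */
class TransformPipeline {
    private final List<UnaryOperator<float[]>> steps = new ArrayList<>();

    TransformPipeline add(UnaryOperator<float[]> step) {
        steps.add(step);
        return this; // fluent chaining
    }

    float[] apply(float[] input) {
        float[] x = input;
        for (UnaryOperator<float[]> step : steps) {
            x = step.apply(x);
        }
        return x;
    }

    /** Scales raw pixel values from [0, 255] to [0, 1]. */
    static float[] toUnitRange(float[] x) {
        float[] out = new float[x.length];
        for (int i = 0; i < x.length; i++) {
            out[i] = x[i] / 255f;
        }
        return out;
    }

    /** Returns a transform computing (x - mean) / std element-wise. */
    static UnaryOperator<float[]> normalize(float mean, float std) {
        return x -> {
            float[] out = new float[x.length];
            for (int i = 0; i < x.length; i++) {
                out[i] = (x[i] - mean) / std;
            }
            return out;
        };
    }
}
```

Order matters: `new TransformPipeline().add(TransformPipeline::toUnitRange).add(TransformPipeline.normalize(0.5f, 0.5f))` maps 0 to -1 and 255 to 1, while adding the same two steps in the opposite order gives a different result.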

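Dataset: the interface that defines a dataset. To implement a custom dataset, users typically extend the abstract class RandomAccessDataset and implement get(manager, index), availableSize(), and prepare(progress).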

import ai.djl.ModelException;
import ai.djl.inference.Predictor;
import ai.djl.modality.Classifications;
import ai.djl.modality.cv.Image;
import ai.djl.modality.cv.ImageFactory;
import ai.djl.repository.zoo.Criteria;
import ai.djl.repository.zoo.ModelZoo;
import ai.djl.repository.zoo.ZooModel;
import ai.djl.translate.TranslateException;

import java.io.IOException;
import java.nio.file.Paths;

public class DjlExample {
    public static void main(String[] args) throws IOException, ModelException, TranslateException {
        // Describe the model; the TensorFlow engine needs the ai.djl.tensorflow:tensorflow-engine dependency.
        // DJL must find a Translator for Image -> Classifications, otherwise pass one with optTranslator(...).
        Criteria<Image, Classifications> criteria = Criteria.builder()
                .setTypes(Image.class, Classifications.class)
                .optEngine("TensorFlow") // choose the engine
                .optModelPath(Paths.get("path/to/model"))
                .build();

        // ZooModel (not plain Model) provides newPredictor() without an explicit Translator
        try (ZooModel<Image, Classifications> model = ModelZoo.loadModel(criteria);
                Predictor<Image, Classifications> predictor = model.newPredictor()) {
            // Load the image
            Image img = ImageFactory.getInstance().fromFile(Paths.get("path/to/image.jpg"));
            // Run inference
            Classifications result = predictor.predict(img);
            System.out.println(result);
        }
    }
}

import ai.djl.Model;
import ai.djl.basicdataset.cv.classification.FashionMnist;
import ai.djl.basicmodelzoo.basic.Mlp;
import ai.djl.metric.Metrics;
import ai.djl.ndarray.types.Shape;
import ai.djl.training.DefaultTrainingConfig;
import ai.djl.training.EasyTrain;
import ai.djl.training.Trainer;
import ai.djl.training.TrainingConfig;
import ai.djl.training.dataset.Batch;
import ai.djl.training.dataset.Dataset;
import ai.djl.training.listener.TrainingListener;
import ai.djl.training.loss.Loss;
import ai.djl.training.optimizer.Optimizer;
import ai.djl.training.tracker.Tracker;
import ai.djl.translate.TranslateException;

import java.io.IOException;

public class DJLDatasetExample {
    public static void main(String[] args) throws IOException, TranslateException {
        FashionMnist fashionMnist = FashionMnist.builder()
                .optUsage(Dataset.Usage.TRAIN)
                .setSampling(32, true) // batch size 32, shuffled
                .optLimit(Long.MAX_VALUE) // lower this value to train on fewer samples
                .build();
        fashionMnist.prepare(); // downloads the data on first use

        try (Model model = Model.newInstance("fashion-mnist-model")) {
            // A model needs a Block before it can be trained; Mlp starts with a flatten block
            model.setBlock(new Mlp(28 * 28, 10, new int[] {128, 64}));

            TrainingConfig config = new DefaultTrainingConfig(Loss.softmaxCrossEntropyLoss())
                    .optOptimizer(Optimizer.sgd().setLearningRateTracker(Tracker.fixed(0.1f)).build())
                    .addTrainingListeners(TrainingListener.Defaults.logging());

            try (Trainer trainer = model.newTrainer(config)) {
                trainer.initialize(new Shape(1, 28 * 28)); // shape of one flattened 28x28 image
                trainer.setMetrics(new Metrics());

                // One epoch: forward and backward pass (trainBatch), then parameter update (step)
                for (Batch batch : trainer.iterateDataset(fashionMnist)) {
                    EasyTrain.trainBatch(trainer, batch);
                    trainer.step();
                    batch.close();
                }
                trainer.notifyListeners(listener -> listener.onEpoch(trainer));
                trainer.notifyListeners(listener -> listener.onTrainingEnd(trainer));
            }
        }
    }
}
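For reference, `Loss.softmaxCrossEntropyLoss()` in the configuration above scores a sample by the negative log of the softmax probability of its true class. For logits z and true class y this is L = -log(exp(z_y) / Σ_j exp(z_j)) = log(Σ_j exp(z_j)) - z_y. The plain-Java sketch below evaluates that formula for one sample using the log-sum-exp form; `CrossEntropyDemo` is a hypothetical helper, not DJL's implementation.

```java
/** Plain-Java sketch of softmax cross-entropy for a single sample. */
final class CrossEntropyDemo {
    private CrossEntropyDemo() {}

    /** Returns log(sum_j exp(z_j)) - z_label, computed with the max-shift trick. */
    static double softmaxCrossEntropy(double[] logits, int label) {
        double max = Double.NEGATIVE_INFINITY;
        for (double z : logits) {
            max = Math.max(max, z);
        }
        double sumExp = 0.0;
        for (double z : logits) {
            sumExp += Math.exp(z - max); // every term is <= 1, so no overflow
        }
        double logSumExp = max + Math.log(sumExp);
        return logSumExp - logits[label];
    }
}
```

With 10 equal logits (an untrained, uninformed model) the loss is log(10) ≈ 2.30, which is a handy sanity check for the first logged loss on FashionMNIST's 10 classes.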

3. Engine and NDManager

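Engine: DJL supports several deep learning engines, such as MXNet, PyTorch, ONNX Runtime, and TensorFlow. The Engine class provides a unified abstraction over them, which makes it easy to switch the underlying engine.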
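The benefit of a unified engine abstraction is that calling code depends only on an interface while the concrete backend is chosen by name. The sketch below is hypothetical: `Backend`, `BackendRegistry`, and the two toy backends are invented for illustration. In DJL itself, `Engine.getInstance()` returns the default engine and `Engine.getEngine(name)` returns a specific one.

```java
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.stream.IntStream;

/** Hypothetical backend interface: callers depend only on this abstraction. */
interface Backend {
    String name();

    float[] add(float[] a, float[] b);
}

/** Hypothetical registry: picks a backend by name, or the first registered one by default. */
class BackendRegistry {
    private final Map<String, Backend> backends = new LinkedHashMap<>();

    void register(Backend backend) {
        backends.put(backend.name(), backend);
    }

    Backend get(String name) {
        Backend b = backends.get(name);
        if (b == null) {
            throw new IllegalArgumentException("No backend named " + name);
        }
        return b;
    }

    Backend getDefault() {
        if (backends.isEmpty()) {
            throw new IllegalStateException("No backend registered");
        }
        return backends.values().iterator().next();
    }
}

/** Toy backend: element-wise addition with a plain loop. */
class LoopBackend implements Backend {
    public String name() { return "Loop"; }

    public float[] add(float[] a, float[] b) {
        float[] out = new float[a.length];
        for (int i = 0; i < a.length; i++) {
            out[i] = a[i] + b[i];
        }
        return out;
    }
}

/** Second toy backend: same contract, different implementation. */
class StreamBackend implements Backend {
    public String name() { return "Stream"; }

    public float[] add(float[] a, float[] b) {
        float[] out = new float[a.length];
        IntStream.range(0, a.length).forEach(i -> out[i] = a[i] + b[i]);
        return out;
    }
}
```

Swapping `registry.get("Loop")` for `registry.get("Stream")` changes the implementation without touching the calling code, which is the property the Engine abstraction provides across real backends.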

NDManager: manages NDArrays (multi-dimensional arrays) and wraps the underlying engine's array operations.

Using DJL Engine

import ai.djl.Model
import ai.djl.ModelException
import ai.djl.ndarray.NDList
import ai.djl.ndarray.types.Shape
import ai.djl.translate.Batchifier
import ai.djl.translate.TranslateException
import ai.djl.translate.Translator
import ai.djl.translate.TranslatorContext
import java.io.IOException
import java.nio.file.Paths

object DJLEngineExample {
    @Throws(ModelException::class, TranslateException::class, IOException::class)
    @JvmStatic
    fun main(args: Array<String>) {
        // The second argument is the engine name ("PyTorch"), not a package name
        Model.newInstance("model-name", "PyTorch").use { model ->
            // Load a pre-trained model; make sure the path is correct
            model.load(Paths.get("path/to/your/model"))

            // Translator for data preprocessing and postprocessing
            val translator = object : Translator<Array<Float>, Float> {
                override fun processInput(ctx: TranslatorContext, input: Array<Float>): NDList {
                    // Add the batch dimension here, because no batchifier is used
                    val array = ctx.ndManager.create(input.toFloatArray()).reshape(Shape(1, input.size.toLong()))
                    return NDList(array)
                }

                override fun processOutput(ctx: TranslatorContext, list: NDList): Float {
                    // Assumes the model returns a single value; toFloatArray() also works for shape (1, 1)
                    return list[0].toFloatArray()[0]
                }

                // Returning null disables batching; the input is already batched above
                override fun getBatchifier(): Batchifier? = null
            }

            model.newPredictor(translator).use { predictor ->
                val input = arrayOf(1.0f, 2.0f, 3.0f) // must match the model's expected input shape
                val output = predictor.predict(input)
                println("Prediction: $output")
            }
        }
    }
}

Overview of NDManager
Key Features of NDManager:
Memory Management: Automates memory allocation and deallocation for NDArrays.
Resource Scope: NDArrays created by an NDManager are tied to the lifecycle of that manager. When the manager is closed, all associated NDArrays are also released.
Hierarchical Structure: NDManagers can create child managers, which manage their own NDArrays. This is useful for managing resources in complex workflows.
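The lifecycle rules above can be modeled in plain Java. `ScopedManager` and `NativeBuffer` are hypothetical stand-ins for `NDManager` and `NDArray` (DJL's own call for creating a child manager is `NDManager.newSubManager()`): every resource is owned by the manager that created it, and closing a manager releases its resources and its child managers.

```java
import java.util.ArrayList;
import java.util.List;

/** Hypothetical stand-in for a native buffer such as an NDArray. */
class NativeBuffer implements AutoCloseable {
    private boolean released;

    boolean isReleased() {
        return released;
    }

    @Override
    public void close() {
        released = true; // a real engine would free native memory here
    }
}

/** Hypothetical manager: owns buffers and child managers and releases them all on close. */
class ScopedManager implements AutoCloseable {
    private final List<AutoCloseable> owned = new ArrayList<>();
    private boolean closed;

    NativeBuffer create() {
        ensureOpen();
        NativeBuffer b = new NativeBuffer();
        owned.add(b);
        return b;
    }

    ScopedManager newChild() {
        ensureOpen();
        ScopedManager child = new ScopedManager();
        owned.add(child);
        return child;
    }

    boolean isClosed() {
        return closed;
    }

    private void ensureOpen() {
        if (closed) {
            throw new IllegalStateException("manager already closed");
        }
    }

    @Override
    public void close() {
        if (closed) {
            return; // closing twice is a no-op
        }
        closed = true;
        // Release in reverse creation order, child managers included
        for (int i = owned.size() - 1; i >= 0; i--) {
            try {
                owned.get(i).close();
            } catch (Exception e) {
                throw new IllegalStateException(e);
            }
        }
        owned.clear();
    }
}
```

Closing a child early frees only its own buffers, which is why short-lived intermediate results are often created under a child manager.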
Using NDManager

import ai.djl.ndarray.NDManager

object NDManagerExample {
    @JvmStatic
    fun main(args: Array<String>) {
        NDManager.newBaseManager().use { manager ->
            val array = manager.create(floatArrayOf(1.0f, 2.0f, 3.0f))
            println("Array: $array")
            // Perform operations
            val result = array.add(2.0f)
            println("Result: $result")
        }
        // No need to explicitly free the memory, it's handled by the NDManager
    }
}

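4. Trainer and Predictor
Trainer class

Trainer provides the interface for training deep learning models. It combines the optimizer, the loss function, and the training loop, and wraps the common steps of training: the forward pass, backpropagation, and parameter updates.

Main features:

Training and validating the model
Managing the optimizer and loss function
An easy-to-use training loop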
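The forward pass, backward pass, and parameter update that a Trainer wraps can be written out by hand for the simplest possible model. The sketch fits y = w·x + b with mean squared error and plain SGD; `LinearSgd` is hypothetical, and its gradients are derived manually, whereas DJL obtains them through automatic differentiation (a `GradientCollector`).

```java
/** Hand-written training loop for y = w * x + b with MSE loss and SGD (illustrative only). */
final class LinearSgd {
    double w;
    double b;

    /** One step: forward pass, MSE gradients, SGD update. Returns the loss before the update. */
    double step(double[] xs, double[] ys, double learningRate) {
        int n = xs.length;
        double loss = 0.0;
        double gradW = 0.0;
        double gradB = 0.0;
        for (int i = 0; i < n; i++) {
            double pred = w * xs[i] + b;    // forward pass
            double err = pred - ys[i];
            loss += err * err / n;          // mean squared error
            gradW += 2.0 * err * xs[i] / n; // dL/dw (backward pass)
            gradB += 2.0 * err / n;         // dL/db
        }
        w -= learningRate * gradW;          // parameter update (what trainer.step() does)
        b -= learningRate * gradB;
        return loss;
    }

    public static void main(String[] args) {
        double[] xs = {0, 1, 2, 3};
        double[] ys = {1, 3, 5, 7}; // generated from y = 2x + 1
        LinearSgd model = new LinearSgd();
        for (int epoch = 0; epoch < 2000; epoch++) {
            model.step(xs, ys, 0.05);
        }
        System.out.printf("w=%.3f b=%.3f%n", model.w, model.b);
    }
}
```

After enough steps w and b approach 2 and 1; a learning rate that is too large for the data scale would make the same loop diverge instead.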

Code demo

The following example trains a simple neural network with DJL's Trainer class:

import ai.djl.Model
import ai.djl.basicdataset.cv.classification.FashionMnist
import ai.djl.basicmodelzoo.basic.Mlp
import ai.djl.metric.Metrics
import ai.djl.ndarray.types.Shape
import ai.djl.training.DefaultTrainingConfig
import ai.djl.training.EasyTrain
import ai.djl.training.TrainingConfig
import ai.djl.training.dataset.Dataset
import ai.djl.training.dataset.RandomAccessDataset
import ai.djl.training.listener.LoggingTrainingListener
import ai.djl.training.listener.TrainingListener
import ai.djl.training.loss.Loss
import ai.djl.training.optimizer.Optimizer
import ai.djl.training.tracker.Tracker
import ai.djl.training.util.ProgressBar
import ai.djl.translate.TranslateException
import java.io.IOException
import java.nio.file.Paths

object DjlTrainerDemo {
    @Throws(IOException::class, TranslateException::class)
    @JvmStatic
    fun main(args: Array<String>) {
        // Load the dataset: batch size 32, shuffled
        val trainDataset: RandomAccessDataset =
            FashionMnist.builder().optUsage(Dataset.Usage.TRAIN).setSampling(32, true).build()
        trainDataset.prepare(ProgressBar())

        // Define the model: MLP with 784 inputs, 10 outputs, hidden layers of 128 and 64 units
        Model.newInstance("mlp").use { model ->
            model.block = Mlp(28 * 28, 10, intArrayOf(128, 64))

            // Training configuration: loss, SGD with a fixed learning rate of 0.01, logging
            val config: TrainingConfig = DefaultTrainingConfig(Loss.softmaxCrossEntropyLoss())
                .optOptimizer(Optimizer.sgd().setLearningRateTracker(Tracker.fixed(0.01f)).build())
                .addTrainingListeners(LoggingTrainingListener())

            model.newTrainer(config).use { trainer ->
                trainer.initialize(Shape(1, (28 * 28).toLong()))
                trainer.metrics = Metrics()
                for (epoch in 0 until 10) {
                    for (batch in trainer.iterateDataset(trainDataset)) {
                        // Forward pass, loss, and backward pass; without this, step() has no gradients
                        EasyTrain.trainBatch(trainer, batch)
                        trainer.step() // parameter update
                        batch.close()
                    }
                    trainer.notifyListeners { listener: TrainingListener -> listener.onEpoch(trainer) }
                }
                // Save the trained parameters into the "model" directory under the name "mlp"
                model.save(Paths.get("model"), "mlp")
            }
        }
    }
}

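Predictor class

Predictor runs inference with a trained model: it passes input data to the model through a simple interface and returns the prediction.

Main features:

Loading a model for inference
Converting input and output data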
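The split of responsibilities between Predictor and Translator can be sketched as three composed stages: preprocess the input, run the model, postprocess the output. `MiniTranslator`, `MiniPredictor`, `CsvArgmaxTranslator`, and the toy model function are hypothetical; they mirror the shape of DJL's `Translator.processInput` / `processOutput` but do not use DJL.

```java
import java.util.function.Function;

/** Hypothetical translator: user input -> model input, model output -> result. */
interface MiniTranslator<I, O> {
    float[] processInput(I input);

    O processOutput(float[] modelOutput);
}

/** Hypothetical predictor: preprocess, run the model, postprocess. */
final class MiniPredictor<I, O> {
    private final Function<float[], float[]> model;
    private final MiniTranslator<I, O> translator;

    MiniPredictor(Function<float[], float[]> model, MiniTranslator<I, O> translator) {
        this.model = model;
        this.translator = translator;
    }

    O predict(I input) {
        float[] x = translator.processInput(input); // preprocessing
        float[] y = model.apply(x);                  // model forward pass
        return translator.processOutput(y);          // postprocessing
    }
}

/** Example translator: parses comma-separated numbers, returns the index of the largest output. */
final class CsvArgmaxTranslator implements MiniTranslator<String, Integer> {
    @Override
    public float[] processInput(String input) {
        String[] parts = input.split(",");
        float[] x = new float[parts.length];
        for (int i = 0; i < parts.length; i++) {
            x[i] = Float.parseFloat(parts[i].trim());
        }
        return x;
    }

    @Override
    public Integer processOutput(float[] modelOutput) {
        int best = 0;
        for (int i = 1; i < modelOutput.length; i++) {
            if (modelOutput[i] > modelOutput[best]) {
                best = i;
            }
        }
        return best;
    }
}
```

For example, `new MiniPredictor<>(x -> x, new CsvArgmaxTranslator()).predict("0.5, 4, -1")` returns 1, the index of the largest value. Keeping conversion logic in the translator lets the same model be reused with different input and output formats.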

Code demo

import ai.djl.MalformedModelException
import ai.djl.Model
import ai.djl.basicmodelzoo.basic.Mlp
import ai.djl.modality.Classifications
import ai.djl.ndarray.NDArray
import ai.djl.ndarray.NDList
import ai.djl.ndarray.NDManager
import ai.djl.ndarray.types.Shape
import ai.djl.translate.Batchifier
import ai.djl.translate.TranslateException
import ai.djl.translate.Translator
import ai.djl.translate.TranslatorContext
import java.io.IOException
import java.nio.file.Paths

object DjlPredictorDemo {
    // FashionMNIST class names, in label order
    private val LABELS = listOf(
        "T-shirt/top", "Trouser", "Pullover", "Dress", "Coat",
        "Sandal", "Shirt", "Sneaker", "Bag", "Ankle boot"
    )

    @Throws(IOException::class, MalformedModelException::class, TranslateException::class)
    @JvmStatic
    fun main(args: Array<String>) {
        Model.newInstance("mlp").use { model ->
            // Saved parameters can only be loaded into the same Block architecture used for training
            model.block = Mlp(28 * 28, 10, intArrayOf(128, 64))
            model.load(Paths.get("model"), "mlp")

            // Define the Translator
            val translator = object : Translator<NDArray, Classifications> {
                override fun processInput(ctx: TranslatorContext, input: NDArray): NDList {
                    // One unbatched sample; the STACK batchifier adds the batch dimension
                    return NDList(input.reshape(Shape((28 * 28).toLong())))
                }

                override fun processOutput(ctx: TranslatorContext, list: NDList): Classifications {
                    // The model outputs 10 logits; softmax turns them into probabilities
                    val probabilities = list.singletonOrThrow().softmax(0)
                    return Classifications(LABELS, probabilities)
                }

                override fun getBatchifier(): Batchifier = Batchifier.STACK
            }

            NDManager.newBaseManager().use { manager ->
                model.newPredictor(translator).use { predictor ->
                    val array = manager.ones(Shape(1, (28 * 28).toLong()))
                    val classifications = predictor.predict(array)
                    println(classifications)
                }
            }
        }
    }
}


