木厶笔记存放站木厶笔记存放站
首页
文章
首页
文章
  • Web

    • 网络

      • 01-01-HTTP 基础
      • 01-02-HTTP1.1 首部字段
      • 01-03-HTTP 优化
      • 01-04-HTTPS
      • 01-05-TCP 和 UDP
      • 01-06-URI 和 URL
      • 01-07-Cookie 和 Session
      • 01-08-输入URL_至页面显示的过程
    • 安全

      • CSRF 攻击
      • JSON Web Token
      • sql 注入
      • XSS 攻击
    • 浏览器机制

      • 提升SEO
      • 浏览器内核
      • 浏览器的进程和线程
      • 浏览器缓存机制
      • 跨域
    • 场景实现

      • 优化大量图片加载
      • 扫码登录
    • 性能优化

      • 性能分析
      • 资源请求优化
      • 运行时优化
      • 重排重绘
      • 静态资源优化
  • HTML

    • 基础

      • Doctype
      • src 和 href
      • title 和 alt
    • HTML5

      • Audio 和 Video
      • Canvas
      • cookie session Storage localStorage
      • Drag
      • Server-Sent-Events
      • Svg
      • WebSocket
      • WebWorker
      • 标签
      • 离线缓存manifest
  • CSS

    • 基础

      • BFC
      • CSS 动画
      • CSS 架构
      • display float position
      • float 浮动
      • 元素不可见
      • 块级元素和行内元素
      • 层叠顺序与堆栈上下文
      • 引入外部CSS
      • 盒模型
      • 雪碧图
    • 布局

      • flex 布局--实用
      • flex 布局
      • Grid 布局--实用
      • Grid 布局
      • 三栏布局
      • 两栏布局
      • 响应式布局
      • 移动端布局
    • 实用

      • 修改svg图片颜色
      • 初始化 CSS
      • 换肤
      • 文本省略号
      • 消除 inline-block 间隙
      • 自定义字体
    • CSS-揭秘

      • 2-1 半透明边框
      • 2-2 多重边框
      • 2-3 灵活的背景定位
      • 2-4 边框内圆角
      • 2-5 条纹背景
      • 2-6 复杂的背景图案
      • 2-8 连续的图像边框
      • 3-1 自适应的椭圆
      • 3-2 平行四边形
      • 3-3 菱形图片
      • 3-4 切角效果
      • 3-5 梯形标签页
      • 3-6 简单的饼图
      • 4-1 单侧投影
      • 4-2 不规则投影
      • 4-3 染色效果
      • 4-4 毛玻璃效果
      • 4-5 折角效果
      • 5-1 连字符断行
      • 5-2 插入换行
      • 5-3 文本行的斑马线条
      • 5-6 华丽的 & 符号
      • 5-7 自定义下划线
      • 5-8 现实中的文字效果
      • 5-9 环形文字
      • 6-1 选用合适的鼠标光标
      • 6-2 扩大可点击区域
      • 6-3 自定义复选框
      • 6-4-5 通过阴影(模糊)来弱化背景
      • 6-6 滚动提示
      • 6-7 交互式的图片对比控件
      • 7-1 自适应内部元素
      • 7-2 精确控制表格列宽
      • 7-3 根据兄弟元素的数量来设置样式
      • 7-4 满幅的背景,定宽的内容
      • 7-5 垂直居中
      • 7-6 紧贴底部的页脚
      • 8-1 缓动效果
      • 8-2 逐帧动画
      • 8-3 闪烁效果
      • 8-4 打字动画
      • 8-5 状态平滑的动画
      • 8-6 沿环形路径平移的动画
    • CSS 选择器

      • 伪类
      • 属性选择器
      • 数学函数
      • 树结构伪类
      • 选择器优先级
      • 选择符
      • 逻辑选择器
    • 奇妙属性

      • @property
      • mask
  • JavaScript

    • 基础

      • 2 script标签
      • 3-1 语法
      • 3-3 var let const
      • 3-4-1 typeof 和 instanceof
      • 3-4-2 Undefined 类型
      • 3-4-3 Null 类型
      • 3-4-4 Boolean 类型
      • 3-4-5 Number 类型
      • 3-4-6 String 类型
      • 3-4-7 Symbol 类型
      • 3-4-8 Object 类型
      • 3-4-9 BigInt 类型
      • 3-5 操作符
      • 3-6 隐式转换
      • 4-1 原始值和引用值
      • 4-2 执行上下文与作用域
      • 4-3 垃圾回收
      • 5-1 Date
      • 5-2 RegExp
      • 5-3 原始值包装类型
      • 5-4-1 Global 单例内置对象
      • 5-4-2 Math 单例内置对象
      • 6-2 Array
      • 6-3 Map
      • 6-4 WeakMap
      • 6-5 Set
      • 6-6 WeakSet
      • 7-1 迭代器与生成器
      • 8-1-1 Object 属性
      • 8-1-2 Object 构造函数方法
      • 8-1-3 Object 语法增强和解构
      • 8-2 Object 创建
      • 8-3 原型和继承
      • 8-4 class 类
      • 8-5 this
      • 8-6 可选链和空值合并运算符
      • 9 proxy 代理与反射
      • 10-1 Function 函数
      • 10-2 apply call bind
      • 10-3 闭包
      • 10-4 私有变量
      • 10-5 扩展运算符和 rest
      • 11-1 event loop 事件循环机制
      • 11-2 Promise 期约
      • 11-3 async await
      • 12-1 BOM
      • 14-1-1 DOM 节点
      • 14-1-2 DOM Document 类型
      • 14-1-3 DOM Element 类型
      • 14-1-4 DOM Text 和 Comment 类型
      • 14-2 动态脚本和动态样式
      • 14-3 MutationObserver 监听节点变化
      • 14-4 Intersection Observer
      • 16-2-1 操控样式
      • 16-2-2 元素尺寸
      • 16-3 DOM 深度优先遍历
      • 17-1 事件流
      • 17-2 事件处理程序
      • 17-3 事件对象
      • 17-4-1 UI 和焦点事件
      • 17-4-2 鼠标和滚轮事件
      • 17-4-3 键盘与输入事件
      • 17-4-4 HTML5 事件
      • 17-4-5 设备和触摸事件
      • 17-6 模拟事件
      • 18-1 requestAnimationFrame&requestIdleCallback
      • 18-2 canvas
      • 19-1 表单基础
      • 19-2 文本框编程
      • 19-3 选择框编程
      • 19-5 富文本编辑器
      • 20-10 Perfromance API
      • 20-11-1 HTML 模板
      • 20-11-2 影子 DOM
      • 20-11-3 自定义元素
      • 20-14 MessageChannel和BroadcastChannel
      • 20-15 AbortController
      • 20-4 File API 与 Blob API
      • 20-7 Notifications API
      • 20-8 Page Visibility API
      • 20-9 Streams API
      • 23 JSON
      • 24-1 XMLHttpRequest
      • 24-5 Fetch API
      • 26-1 模块语法
      • 26-2 CommonJs 与 ES6 Module 的差异
    • 设计模式

      • 单例模式
      • 发布订阅模式
      • 工厂模式
      • 策略模式
      • 装饰器模式
      • 观察者模式
      • 适配器模式
    • 手写实现

      • apply call bind
      • flat
      • instanceof
      • JSONP
      • new
      • Promise API
      • trim
      • 数组去重
      • 柯里化
      • 深拷贝
      • 防抖和节流
    • 场景题

      • 映射 URL 参数
      • 模拟红绿灯
      • 请求-丢弃旧时序的请求
      • 请求-并发请求
      • 通过 value 找 key
      • 闭包题
    • 移动端

      • touch 事件
      • visualViewport
      • 像素
      • 移动端布局
      • 视口 Viewport
    • 实用

      • better-scroll 滚动组件
      • Proxy实现英文字母升降序
      • 区别数组和对象
      • 图片懒加载
      • 按首字母排序的列表
      • 控制粘贴板
    • PIXI

      • 1 Application
      • 2 Graphics
      • 3 loader
      • 4 Sprite
      • 5 Spine
      • 6 事件
      • 7 Renderer
  • Node

    • 基础

      • crypto 加密模块
      • fs 模块
      • http 模块
      • mysql 模块
      • redis 模块
    • 框架

      • express
      • koa2
    • 实用

      • @elasticelasticsearch
      • restful Mock 数据
      • 自定义 Mock 数据
    • 错误

      • pm2-watch报错 502
      • spawn 中文乱码
  • Jquery

    • 基础

      • 事件
      • 动画
      • 工具方法
      • 操作 dom
      • 获取元素
  • TypeScript

    • 环境

      • config
      • 环境
    • 基础

      • 01 原始类型和特殊类型
      • 02 字面量和类型拓宽
      • 03 interface 和 type
      • 04 数组和元组
      • 05 class(类)
      • 06 函数和重载
      • 07 联合类型和交叉类型
      • 08 泛型
      • 09 类型推断和类型断言
      • 10 匹配提取
      • 11 重新构造
      • 12 递归循环
      • 13数组长度做计数
      • 14 特殊特性
      • 15 内置高级类型
      • 16 inter extends
      • 17 协变和逆变
  • Vue

    • 环境

      • 安装
      • 自定义环境变量
    • 基础(2.x)

      • data
      • keep-alive
      • nextTick
      • props和sync
      • ref
      • v-for和v-if
      • 事件绑定
      • 动态组件
      • 动画
      • 循环渲染
      • 插槽 slot
      • 条件渲染
      • 样式绑定
      • 模板语法
      • 生命周期
      • 组件通讯
      • 自定义指令
      • 表单绑定
      • 计算属性和监听器
    • 基础(3.x)

      • Composition API
      • Script setup
      • Suspense
      • sync 语法糖
      • Teleport
      • vue3的升级点
      • 常用动画
      • 生命周期
    • router

      • vue-router 3
      • vue-router 4
    • vuex

      • pinia
      • vuex
      • vuex4
      • 刷新不丢失 vuex
    • 底层

      • MVVM
      • 双向绑定
      • 响应式
      • 模板编译
      • 渲染过程
      • 虚拟DOM和diff算法
    • 应用

      • 权限管理
  • React

    • 环境

      • 安装
      • 自定义环境变量
    • 基础

      • JSX
      • ref
      • SCU
      • 不可变数据 setState
      • 事件
      • 动画
      • 异步组件
      • 插槽
      • 条件
      • 生命周期
      • 组件公共逻辑抽离
      • 组件通讯
      • 表单
      • 逃离组件 portals
    • hooks

      • hooks
      • react-query
      • useClickOutSide —— 点击外面
      • useDebounce —— 防抖
      • useURLQueryParam —— 输入框值与search绑定
    • router

      • react-router-com 6
      • react-router-dom hooks
      • react-router-dom
    • redux

      • react-redux
      • redux 与 hooks
      • redux-persist
      • redux-thunk
      • redux-toolkit
      • redux
    • 底层

      • JSX本质
      • setState 原理
      • 合成事件
    • 实用

      • CSS-in-JS
      • 字体库
      • 数组优化为哈希表
      • 组件的子元素只能是规定的元素
  • Echarts

    • 基础

      • 仪表盘
      • 基础
      • 折线图
      • 散点图
      • 柱状图
      • 雷达图
      • 饼图
  • Electron

    • 环境

      • 基本安装
      • 集合 react
    • 基础

      • Dialog
      • 原生菜单
      • 右键菜单
      • 进程
  • 前端工程化

    • babel

      • babel7 实践
      • 工作原理
      • 生态
    • Browserslist

      • 基础
    • npm

      • npm模块安装机制
      • npm脚本
    • qiankun

      • 隔离原理
    • Vite

      • css
      • gzip-打包
      • vite config 常见配置
      • 环境
      • 静态资源
    • webpack

      • css环境
      • js&ts环境
      • loader
      • loader和plugin的区别及编写
      • plugin
      • splitChunks
      • webpack5--模块联邦
      • 代码压缩
      • 优化性能
      • 图片
      • 性能分析工具
      • 提高构建速度
      • 构建性能--并行
      • 构建性能--持久化缓存
      • 热更新
    • 代码提交规范

      • husky&lint-staged
  • Java

    • 环境

      • idea 创建 maven-archetype-webapp
      • spring boot web 的配置参考
      • spring 从初始配置
      • 与 git 集合
    • Web 基础

      • mybatis pager的应用
      • [] 和 List 和 Set
      • 全局 cors 跨域
      • 接口例子——登录
      • 文件接口——图片
      • 有效时间的唯一字符串
      • 自定义类
      • 通用的泛型服务端响应对象
  • Elastic

    • 基础

      • 分词规则
      • 基础概念
      • 基础语法
      • 环境搭建
    • 实用

      • canal——数据库准实时导入
      • logstash——数据库基于时间轴导入
  • Mysql

    • 基础

      • 字段操作
      • 数据库操作
      • 查询操作
      • 表操作
  • Python

    • 基础

      • 1-1-Number
      • 1-2-String
      • 1-3-list和tuple
      • 1-4-序列
      • 1-5-set
      • 1-6-dict
      • 2-1-运算符
      • 2-2-对象比较
      • 3-1-条件判断
      • 3-2-循环
      • 4-1-模块-包
      • 4-2-模块-__init__
      • 4-3-模块-内置变量
      • 4-4-模块-导入
      • 5-1-解包
      • 5-2-函数参数
      • 5-3-作用域
      • 5-4-高阶函数&三元表达式
      • 6-1-类
      • 6-2-类的继承
      • 6-3-枚举
      • 7-1-json
      • 8-1-文件IO
      • 8-2-pathlib
  • Flutter

    • 环境

      • 创建及运行项目
      • 第三方库
      • 运行环境
      • 静态资源
    • Dart

      • dynamic var object
      • List
      • Map
      • Number
      • Set
      • String
      • URI
      • 函数
      • 变量
      • 库
      • 异步
      • 流程控制语句
      • 类
      • 运算符
    • 基础

      • 1 Widget
      • 2 StatelessWidget & StatefulWidget
      • 3 State生命周期
      • 4 状态管理
      • 5 路由
    • 基础组件

      • Button
      • Text
      • 单选开关和复选框
      • 图片及ICON
      • 表单
      • 输入框
      • 进度指示器
    • 布局组件

      • 1 布局约束
      • 2 线性布局(Row & Column)
      • 3 弹性布局(Flex)
      • 4 流式布局(Wrap & Flow)
      • 5 层叠布局(Stack & Positioned)
      • 6 绝对定位(Align)
      • 7 LayoutBuilder & AfterLayout
    • 容器类组件

      • Container
      • Scaffold
      • 变换-Transform
      • 空间适配-FittedBox
      • 裁剪-Clip
      • 装饰-DecoratedBox
    • 可滚动组件

      • SingleChildScrollView
      • 通用属性
  • Git

    • 基础

      • Git commit message 规范
    • 实用

      • cherry-pick
      • stash
      • 分支&修改最近一次commit
      • 撤销
  • 算法

    • 基础

      • leetcode
      • 时间复杂度
    • 收录

      • 阿拉伯数字转中文
    • 数据结构

      • 哈希表、集合
      • 并查集
      • 数组、链表、跳表
      • 栈、队列
      • 树、二叉树、二叉搜索树
    • 算法

      • DFS 和 BFS
      • 二分查找
      • 二叉树路径问题题目
      • 动态规划2--基础题
      • 动态规划2--背包问题
      • 区间类题目
      • 岛屿类题目
      • 排列组合子集类题目
      • 排序算法
      • 摩尔投票
      • 递归
      • 链表
  • 部署

    • centos8

      • bitwarden_rs
      • ElasticSearch
      • git
      • https及http2
      • mac 向 centos 传输文件
      • nginx
      • nvm
    • 其他

      • docker
      • sitemap
      • 重装系统
  • 图形学

    • 基础

      • 1 三角函数
      • 2 斜率k
      • 3 向量
      • 4 矩阵
  • AI

    • huggingface

      • 01-01-使用和下载模型
      • 01-02-数据集
      • 01-03-文本分类任务微调
    • langchain

      • 01-01-Message
      • 01-02-promptTemplate-提示词模版
      • 01-03-outputParser-输出解析器
      • 01-04-fewShot
      • 02-01-LECL&chain
      • 02-02-stream
      • 03-01-debug
      • 04-01-多模态
      • 05-01-tool
      • 06-01-agent
      • 07-01-embedding基础
      • 07-02-向量数据库
      • 08-01-RAG
    • langgraph

      • 1-01-图
      • 1-02-1-状态
      • 1-02-2-消息状态
      • 1-03-节点
      • 1-04-边
      • 1-05-Send
      • 1-06-Command
      • 1-07-配置-configurable
      • 1-08-1-内存持久性
      • 1-08-2-删除持久消息
      • 1-08-3-向量数据库
      • 1-08-4-持久上下文token优化
      • 1-09-1-工具
      • 1-09-2-大量工具优化
      • 1-10-1-人机交互-interrupt
      • 1-10-2-人机交互-interrupt_before
      • 1-11-1-流式输出
      • 1-11-2-自定义流式传输
      • 1-11-3-输出特定的流式消息
      • 1-12-1-子图
      • 1-13-1-ReAct
    • mcp

      • mcp客户端
      • 服务器——python
      • 服务器——ts
    • prompt

      • 提示词策略
      • 输出output

文本分类模型微调

介绍

基于BERT(Bidirectional Encoder Representations from Transformers)的语义相似度算法是一种强大的自然语言处理工具,它能够理解文本之间的深层语义关系。它通过双向Transformer编码器来预训练深层的语言表示。

实现语义相似度算法的步骤:

预训练BERT模型:首先需要在一个大规模的文本语料库上预训练BERT模型,使其学习语言的通用表示。

微调:在特定的任务数据集上对预训练的BERT模型进行微调,以适应特定的应用场景。

提取特征:使用微调后的模型提取文本的特征向量。

计算相似度:通过比较这些特征向量来计算文本之间的语义相似度。

环境准备

  1. 安装 huggingface 库
pip install transformers datasets
  1. 下载预训练模型
  • 一般使用中文的 bert 模型(bert-base-chinese)
from transformers import BertTokenizer, BertForSequenceClassification

#替换为你选择的模型名称
model_name = "bert-base-chinese"
#指定模型保存路径
cache_dir = "./model_cache"

#下载并加载模型和分词器到指定文件夹
model = BertForSequenceClassification.from_pretrained(model_name, cache_dir=cache_dir)
tokenizer = BertTokenizer.from_pretrained(model_name, cache_dir=cache_dir)
  1. 下载微调使用的数据集
  • huggingface 筛选文本分类数据集:在左侧筛选栏里选择 tasks,然后选择 text-classification(注意这不是必须的,有些文本分类数据集没有被打上 text-classification 标签,筛选了反而搜索不到)
  • 下面使用 weibo 的数据集
from datasets import load_dataset

# 加载在线的数据集
dataset = load_dataset(path="kuroneko5943/weibo16" cache_dir=r"cache_dataset/")
# 缓存数据到本地
dataset.save_to_disk(dataset_path="cache_dataset/mydata/")
  • 数据集会有两个重要的部分,一个是文本,一个是文本对应的分类值。一般会如下,但是不同数据集 key 会不一样,在后面代码中需要动态调整
    • label:分类值
    • text:文本
label,text
1,崩溃了。
7,文革之后还能保留这个习俗也算异数了。
7,其实更好。

训练

1、自定义训练集类

  • split入参是要传入数据集类型,一般类型分为 训练数据集、验证数据集、测试数据集(有的数据集可能只有某一种类别的数据集,看情况下 MyDataset 类)
  • __getitem__() 方法需要返回使用数据集的文本和分类值,根据模型对应 key 返回就行

from torch.utils.data import Dataset
from datasets import load_dataset

class MyDataset(Dataset):
    def __init__(self, split):
        self.dataset = load_dataset(path="csv", data_files=f"data/Weibo/{split}.csv", split=split)
        
    def __len__(self):
        return len(self.dataset)
    
    def __getitem__(self, item):
        text = self.dataset[item]["text"] # 文本
        label = self.dataset[item]["label"] # 分类值
        return text,label

2、定义下游任务类

  • torch.nn.Linear(768, 8):将模型的 768 维特征转换为数据集的类别数特征
    • 768 是 bert-base-chinese 模型的 out_features,可以通过 print(pretrained) 看到使用模型的 out_features
    • 8是数据集分类的种类数,示例使用的 weibo 数据集会将文本分为8种类别
  • forward
    • input_ids:这是输入文本转成词向量后的 id 数据,即调用模型的token.batch_encode_plus进行文本 -> 词向量 后的结果
    • attention_mask:这是标出词向量哪一部分是有用。这是因为每一个文本都会转换为相同长度的词向量,但是文本不是同样长的,不够长的文本被转换为词向量后剩余部分会用 0 补齐。attention_mask就是标记非 0 部分的词向量
from transformers import BertModel
import torch

# 模型加载根据环境动态选择 cpu 和 gpu
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# 加载模型
pretrained = BertModel.from_pretrained(r"xxx\models--bert-base-chinese\snapshots\c30a6ed22ab4564dc1e3b2ecbf6e766b0611a33f").to(DEVICE)

#定义下游任务(将主干网络所提取的特征进行分类)
class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # 768维特征 -> 2分类
        self.fc = torch.nn.Linear(768, 8)

    def forward(self, input_ids, attention_mask, token_type_ids):
        with torch.no_grad(): # 冻结BERT权重
            out = pretrained(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)
        out = self.fc(out.last_hidden_state[:,0]) # 取[CLS]标记的输出
        out = out.softmax(dim=1) # 概率分布
        return out

3、训练

  • 如果有验证数据集,最好可以在训练完一轮后,使用验证集进行验证,当验证的评分大于历史最好评分(或者验证损失值小于历史最小损失值,二者去一个去验证就行),才保存这轮结果。不然可能会训练过拟合,训练效果反而不好
from torch.utils.data import Dataset, DataLoader
from datasets import load_from_disk
from transformers import BertModel, BertTokenizer
from torch.optim import AdamW
import torch
from MyData import MyDataset
from net import Model

# 训练轮数
EPOCH = 100
# 检测使用GPU还是CPU
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# 加载预训练模型
pretrained = BertModel.from_pretrained(r"/xxxl/models--bert-base-chinese/snapshots/8f23c25b06e129b6c986331a13d8d025a92cf0ea").to(DEVICE)
# 加载tokenizer
token = BertTokenizer.from_pretrained(r"/xxx/models--bert-base-chinese/snapshots/8f23c25b06e129b6c986331a13d8d025a92cf0e")
    
# 数据批处理函数
def collate_fn(data):
    sentes = [i[0] for i in data]
    label = [i[1] for i in data]
    # 对句进行编码成id,
    data = token.batch_encode_plus(batch_text_or_text_pairs=sentes,
                            truncation=True,
                            padding="max_length",
                            max_length=500,
                            return_tensors="pt",
                            return_length=True)
    input_ids = data["input_ids"]
    attention_mask = data["attention_mask"]
    token_type_ids = data["token_type_ids"]
    labels = torch.LongTensor(label)
    return input_ids,attention_mask,token_type_ids,labels
#创建数据集
train_dataset = MyDataset("train") # 训练数据集
val_dataset = MyDataset("validation") # 验证数据集
#创建dataloader
train_laoder = DataLoader(dataset=train_dataset,
                          batch_size=100,
                          shuffle=True,
                          drop_last=True,
                          collate_fn=collate_fn)


model = Model().to(DEVICE)
optimizer = AdamW(model.parameters(),lr=5e-4)
loss_func = torch.nn.CrossEntropyLoss()

# 训练循环 
for epoch in range(EPOCH):
    # 训练
    model.train()
    for i,(input_ids,attention_mask,token_type_ids,labels) in enumerate(train_laoder):
        input_ids, attention_mask, token_type_ids, labels = input_ids.to(DEVICE), attention_mask.to(
            DEVICE), token_type_ids.to(DEVICE), labels.to(DEVICE)
        
        # 前向传播
        out = model(input_ids,attention_mask,token_type_ids)
        # 计算损失
        loss = loss_func(out,labels)

        # 反向传播
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if i%5==0:
            out = out.argmax(dim=1)
            acc = (out==labels).sum().item()/len(labels)
            print(epoch,i,loss.item(),acc)
            
    #模型验证
    model.eval()
    for i, (input_ids, attention_mask, token_type_ids, labels) in enumerate(val_loader):
        input_ids, attention_mask, token_type_ids, labels = input_ids.to(DEVICE), attention_mask.to(
            DEVICE), token_type_ids.to(DEVICE), labels.to(DEVICE)
        out = model(input_ids, attention_mask, token_type_ids)

        loss = loss_func(out, labels)

        out = out.argmax(dim=1)
        accuracy = (out == labels).sum().item() / len(labels)
        sum_val_loss = sum_val_loss + loss
        sum_val_acc = sum_val_acc + accuracy
    avg_val_loss = sum_val_loss / len(val_loader)
    avg_val_acc = sum_val_acc / len(val_loader)
    print(f"val==>epoch:{epoch},avg_val_loss:{avg_val_loss}, avg_val_acc:{avg_val_acc}")

    # 只有当评分大于历史最好评分,才保存数据结果(验证集的作用)
    #if avg_val_acc > best_avg_val_acc:
    #   torch.save()
    #   best_avg_val_acc = avg_val_acc
    torch.save(model.state_dict(),f"cache_dataset/train_data/{epoch}bert01.pth")
    print(epoch,"参数保存成功!")

、使用训练数据,对用户输入进行分类

names = ["like","disgust","happiness","sadness","anger","surprise","fear","none"]

def test():
    model.load_state_dict(torch.load("cache_dataset/train_data/0bert01.pth"))
    model.eval()
    while True:
        data = input("请输入测试数据(输入'q'退出):")
        if data == "q":
            print("测试结束")
            break
        input_ids, attention_mask, token_type_ids= collate_fn(data)
        input_ids, attention_mask, token_type_ids = input_ids.to(DEVICE), attention_mask.to(
            DEVICE), token_type_ids.to(DEVICE)

        with torch.no_grad():
            out = model(input_ids,attention_mask,token_type_ids)
            out = out.argmax(dim=1)
            print("模型判定:",names[out],"\n")
            
test()
最近更新: 2025/8/24 12:59
Contributors: kingmusi
Prev
01-02-数据集