分享

Mahout协同过滤框架Taste的源码分析(1)



问题导读:
1.Mahout如何优化内存开销?
2.Taste如何实现基于用户的推荐?







Taste
相关简介及使用参见协同过滤原理与Mahout实现

包结构
  • Taste实现
    org.apache.mahout.cf.taste

  • common
    公有类, 包括: 异常、数据刷新接口、权重常量

  • eval
    定义构造器接口, 类似于工厂模式

  • model
    定义数据模型接口

  • transforms
    定义数据转换的接口

  • similarity
    定义相似度算法的接口

  • neighborhood
    定义近邻算法的接口

  • recommender
    定义推荐算法的接口

  • impl
    基于单机内存的推荐流程实现

  • hadoop
    基于集群的分布式推荐流程实现


偏好

基本接口
偏好信息由用户ID、物品ID、偏好分值组成

  1. public interface Preference {
  2.     long getUserID();
  3.     long getItemID();
  4.     float getValue();
  5.     void setValue(float value);
  6. }
复制代码


偏好数组
  1. // 对于偏好信息的存储
  2. // 为了节省内存, Mahout没有直接采用数组
  3. public interface PreferenceArray extends Cloneable, Serializable, Iterable<Preference> {
  4.     void setUserID(int i, long userID);
  5.     void setItemID(int i, long itemID);
  6.     void setValue(int i, float value);
  7. }
复制代码


具体实现-用户偏好信息
  1. public final class GenericUserPreferenceArray implements PreferenceArray {
  2.     private long id;
  3.     // 为节省内存, 使用数组存储
  4.     private final long[] ids;
  5.     private final float[] values;
  6.     @Override
  7.     public long getUserID(int i) {
  8.         return id;
  9.     }
  10.     @Override
  11.     public long getItemID(int i) {
  12.         return ids[i];
  13.     }
  14.     @Override
  15.     public void setItemID(int i, long itemID) {
  16.         ids[i] = itemID;
  17.     }
  18. }
复制代码



具体实现-物品偏好信息
  1. // 与用户偏好信息数组的实现相同
  2. public final class GenericItemPreferenceArray implements PreferenceArray {
  3.     private final long[] ids;
  4.     private long id;
  5.     private final float[] values;
  6.     @Override
  7.     public long getItemID(int i) {
  8.         return id;
  9.     }
  10.     @Override
  11.     public long getUserID(int i) {
  12.         return ids[i];
  13.     }
  14.     @Override
  15.     public void setUserID(int i, long userID) {
  16.         ids[i] = userID;
  17.     }
  18. }
复制代码



内存问题
在 Java 中,一个对象占用的字节数 = 基本的8字节 + 基本数据类型所占的字节 + 对象引用所占的字节。

基本的8字节
在 JVM 中,每个对象(数组除外)都有一个头,这个头有两个字,第一个字存储对象的一些标志位信息,如:锁标志位、经历了几次 gc 等信息;第二个字节是一个引用,指向这个类的信息。JVM 为这两个字留了8个字节的空间。这样一来的话,new Object() 就占用了8个字节,那怕它是个空对象。

基本数据类型
  1. byte/boolean   1bytes
  2. char/short     2bytes
  3. int/float      4bytes
  4. double/long    8bytes
复制代码


对象引用
  1. reference      4bytes
复制代码



内存占用对比
一个GenericPreference对象需要28字节, 其中userID用8字节, itemID用8字节, value用4字节, 还有基本的8字节。

Array of Preferences方式
2-1.png
基本8字节 + 一个Preference对象的28字节 + 4个空引用的12字节 = 48字节

PreferenceArray方式
2-2.png
基本8字节 + userID的8字节 + itemID的8字节 + value的4字节 = 28字节



数据
基本接口
  1. // 用户偏好信息的压缩表示, 为各类推荐算法提供了对数据的高效访问
  2. public interface DataModel extends Refreshable, Serializable {
  3.     // 用于计算的一些常用方法
  4.     LongPrimitiveIterator getUserIDs() throws TasteException;
  5.     PreferenceArray getPreferencesFromUser(long userID) throws TasteException;
  6.     FastIDSet getItemIDsFromUser(long userID) throws TasteException;
  7.     Float getPreferenceValue(long userID, long itemID) throws TasteException;
  8. }
复制代码



具体实现–基于内存
  1. public final class GenericDataModel extends AbstractDataModel {
  2.     // 存储用户标示, 物品标示, 偏好信息
  3.     private final long[] userIDs;
  4.     private final long[] itemIDs;
  5.     private final FastByIDMap<PreferenceArray> preferenceFromUsers;
  6.     public GenericDataModel(FastByIDMap<PreferenceArray> userData) {
  7.         ......
  8.     }
  9. }
复制代码



具体实现–基于文件
  1. public class FileDataModel extends AbstractDataModel {
  2.     // 分隔符
  3.     private static final char[] DELIMIETERS = { ',', '\t' };
  4.     // 数据文件
  5.     private final File dataFile;
  6.     // 用于存储用户, 物品, 偏好信息
  7.     private DataModel delegate;
  8.     // 查找相同目录下的增量文件进行处理
  9.     // 增量文件格式要求第一个'.'符号前的名字相同
  10.     // 比如: /foo/data.txt.gz
  11.     // /foo/data.1.txt.gz, /foo/data.2.txt.gz
  12.     protected DataModel buildModel() throws IOException {
  13.         processFile()
  14.     }
  15.     // 逐行处理文件
  16.     // 将内容转换为PreferenceArray以便于构建GenericDataModel对象
  17.     // 要求文件格式为: userID,itemID[,preference[,timestamp]]
  18.     protected void processLine(String line, FastByIDMap<?> data,
  19.         FastByIDMap<FastByIDMap<Long>> timestamps, boolean fromPriorData) {
  20.             ...
  21.     }
  22. }
复制代码

内存问题
Java自带的HashMap和HashSet占用内存较大, 因此Mahout对这两个常用的数据结构也做了为减少内存开销的精简实现。

FastIDSet的每个元素平均占14字节, 而HashSet而需要84字节;
FastByIDMap的每个Entry占28字节, 而HashMap则需要84字节。
改进如此显著的原因在于:

* 和HashMap一样, FastByIDMap也是基于hash的。不过FastByIDMap使用的是线性探测来解决hash冲突, 而不是分割链;
* FastByIDMap的key和值都是long类型, 而不是Object, 这是基于节省内存开销和改善性能所作的改良;
* FastByIDMap类似于一个缓存区, 它有一个maximum size的概念, 当我们添加一个新元素的时候, 如果超过了这个size, 那些使用不频繁的元素就会被移除。



相似度
基本接口推测用户对某个物品的偏好
  1. public interface PreferenceInferrer extends Refreshable {
  2.     float inferPreference(long userID, long itemID) throws TasteException;
  3. }
复制代码


用户相似度
  1. public interface UserSimilarity extends Refreshable {
  2.     // 返回值范围 -1.0 ~ 1.0, Double.NaN表示相似度未知
  3.     double userSimilarity(long userID1, long userID2) throws TasteException;
  4.     // 设置PreferenceInferrer接口的实现类, 以用于推测偏好
  5.     void setPreferenceInferrer(PreferenceInferrer inferrer);
  6. }
复制代码


物品相似度
  1. public interface ItemSimilarity extends Refreshable {
  2.     // 返回值范围 -1.0 ~ 1.0, Double.NaN表示相似度未知
  3.     double itemSimilarity(long itemID1, long itemID2) throws TasteException;
  4.     double[] itemSimilarities(long itemID1, long[] itemID2s) throws TasteException;
  5.     long[] allSimilarItemIDs(long itemID) throws TasteException;
  6. }
复制代码


抽象类实现物品相似度的抽象类
  1. public abstract class AbstractItemSimilarity implements ItemSimilarity {
  2.     private final DataModel dataModel;
  3.     @Override
  4.     public long[] allSimilarItemIDs(long itemID) throws TasteException {
  5.         // 从dataModel中获取所有的物品
  6.         // 将相似度不为NaN的物品ID添加到返回数组中
  7.         FastIDSet allSimilarItemIDs = new FastIDSet();
  8.         LongPrimitiveIterator allItemIDs = dataModel.getItemIDs();
  9.         while (allItemIDs.hasNext()) {
  10.             long possiblySimilarItemID = allItemIDs.nextLong();
  11.             if (!Double.isNaN(itemSimilarity(itemID, possiblySimilarItemID))) {
  12.                 allSimilarItemIDs.add(possiblySimilarItemID);
  13.             }
  14.         }
  15.         return allSimilarItemIDs.toArray();
  16.     }
  17. }
复制代码



实现用户、物品相似度的抽象类
  1. // 继承AbstractItemSimilarity使得该类可以计算物品相似度
  2. // 实现UserSimilarity接口使得该类可以计算用户相似度
  3. abstract class AbstractSimilarity extends AbstractItemSimilarity implements UserSimilarity {
  4.     // 是否考虑重叠部分对结果的影响
  5.     private final boolean weighted;
  6.     // 是否取标准差
  7.     private final boolean centerData;
  8.     @Override
  9.     public double userSimilarity(long userID1, long userID2) throws TasteException {
  10.         // 计算 count, sumXY, sumX2, sumY2, sumXYdiff2
  11.         // 然后调用computeResult获取相似度
  12.         ......
  13.     }
  14.     @Override
  15.     public final double itemSimilarity(long itemID1, long itemID2) throws TasteException {
  16.         // 计算 count, sumXY, sumX2, sumY2, sumXYdiff2
  17.         // 然后调用computeResult获取相似度
  18.         ......
  19.     }
  20.     // 具体在各个子类中实现
  21.     abstract double computeResult(int n, double sumXY, double sumX2, double sumY2, double sumXYdiff2);
  22. }
复制代码



具体实现欧几里得距离
  1. public final class EuclideanDistanceSimilarity extends AbstractSimilarity {
  2.     @Override
  3.     double computeResult(int n, double sumXY, double sumX2, double sumY2, double sumXYdiff2) {
  4.         // 这里对普通的欧式距离进行了修正
  5.         return 1.0 / (1.0 + Math.sqrt(sumXYdiff2) / Math.sqrt(n));
  6.     }
  7. }
复制代码



皮尔逊距离
  1. public final class PearsonCorrelationSimilarity extends AbstractSimilarity {
  2.     public PearsonCorrelationSimilarity(DataModel dataModel, Weighting weighting) throws TasteException {
  3.         // 取均值
  4.         super(dataModel, weighting, true);
  5.     }
  6.     @Override
  7.     double computeResult(int n, double sumXY, double sumX2, double sumY2, double sumXYdiff2) {
  8.         // 皮尔逊距离用于度量两个维度之间的线性相关性
  9.         // 统计学的意义为两个序列协方差与两者方差乘积的比值
  10.         if (n == 0) {
  11.             return Double.NaN;
  12.         }
  13.         double denominator = Math.sqrt(sumX2) * Math.sqrt(sumY2);
  14.         if (denominator == 0.0) {
  15.             return Double.NaN;
  16.         }
  17.         return sumXY / denominator;
  18.     }
  19. }
复制代码



邻居

基本接口
  1. // 用户邻居接口
  2. public interface UserNeighborhood extends Refreshable {
  3.     // 需要根据用户ID给出其邻居用户的ID列表
  4.     long[] getUserNeighborhood(long userID) throws TasteException;
  5. }
复制代码



抽象类
  1. // 规定使用指定的相似度实现类从指定的数据模型中查找邻居用户
  2. // getUserNeighborhood方法在继承者中实现
  3. abstract class AbstractUserNeighborhood implements UserNeighborhood {
  4.     private final UserSimilarity userSimilarity;
  5.     private final DataModel dataModel;
  6. }
复制代码



具体实现基于个数
  1. public final class NearestNUserNeighborhood extends AbstractUserNeighborhood {
  2.     // 邻居个数
  3.     private final int n;
  4.     // 最小距离
  5.     private final double minSimilarity;
  6.     @Override
  7.     public long[] getUserNeighborhood(long userID) throws TasteException {
  8.         // 计算相似度, 返回TopN
  9.         ......
  10.     }
  11. }
复制代码



基于距离
  1. public final class ThresholdUserNeighborhood extends AbstractUserNeighborhood {
  2.     // 距离门槛
  3.     private final double threshold;
  4.     @Override
  5.     public long[] getUserNeighborhood(long userID) throws TasteException {
  6.         // 遍历用户, 计算相似度, 返回大于门槛值的用户
  7.         ......
  8.     }
  9. }
复制代码



基于内存
  1. public final class CachingUserNeighborhood implements UserNeighborhood {
  2.     // 使用一种计算邻居的实现
  3.     private final UserNeighborhood neighborhood;
  4.     // 存储用户的邻居
  5.     private final Cache<Long, long[]> neighborhoodCache;
  6.     public CachingUserNeighborhood(UserNeighborhood neighborhood, DataModel dataModel) throws TasteException {
  7.         this.neighborhood = neighborhood;
  8.         int maxCacheSize = dataModel.getNumUsers();
  9.         // 使用内部类实现邻居用户的计算
  10.         this.neighborhoodCache = new Cache<Long, long[]>(new NeighborhoodRetriever(neighborhood), maxCacheSize);
  11.     }
  12.     @Override
  13.     public long[] getUserNeighborhood(long userID) throws TasteException {
  14.         return neighborhoodCache.get(userID);
  15.     }
  16.     // 内部类
  17.     // 使用计算邻居的具体实现来实例化
  18.     private static final class NeighborhoodRetriever implements Retriever<Long, long[]> {
  19.         private final UserNeighborhood neighborhood;
  20.         private NeighborhoodRetriever(UserNeighborhood neighborhood) {
  21.             this.neighborhood = neighborhood;
  22.         }
  23.         @Override
  24.         public long[] get(Long key) throws TasteException {
  25.             return neighborhood.getUserNeighborhood(key);
  26.         }
  27.     }
  28. }
复制代码



推荐

基本接口推荐项
  1. // 定义了推荐项的基本属性
  2. public interface RecommendedItem {
  3.     long getItemID();
  4.     float getValue();
  5. }
复制代码



推荐器
  1. // 定义了推荐器应提供的基本方法
  2. public interface Recommender extends Refreshable {
  3.     List<RecommendedItem> recommend(long userID, int howMany) throws TasteException;
  4. }
复制代码


基于用户推荐
  1. // 基于用户推荐的接口
  2. public interface UserBasedRecommender extends Recommender {
  3.     // 需要实现根据用户ID, 指定推荐数目来给用户提供推荐
  4.     long[] mostSimilarUserIDs(long userID, int howMany) throws TasteException;
  5. }
复制代码


基于物品推荐
  1. // 基于物品的推荐
  2. public interface ItemBasedRecommender extends Recommender {
  3.     // 需要实现根据物品ID和指定推荐数量来给出推荐的物品
  4.     List<RecommendedItem> mostSimilarItems(long itemID, int howMany) throws TasteException;
  5. }
复制代码


抽象类
  1. public abstract class AbstractRecommender implements Recommender {
  2.     // 用于存储用户偏好信息的属性
  3.     private final DataModel dataModel;
  4. }
复制代码


具体实现基于用户
  1. public class GenericUserBasedRecommender extends AbstractRecommender implements UserBasedRecommender {
  2.     // 用于计算相似度的属性
  3.     private final UserSimilarity similarity;
  4.     // 用于计算邻居用户的属性
  5.     private final UserNeighborhood neighborhood;
  6.     @Override
  7.     public List<RecommendedItem> recommend(long userID, int howMany, IDRescorer rescorer) throws TasteException{
  8.         // 获取邻居用户
  9.         long[] theNeighborhood = neighborhood.getUserNeighborhood(userID);
  10.         // 从邻居用户中查找当前用户没有偏好信息的物品ID
  11.         FastIDSet allItemIDs = getAllOtherItems(theNeighborhood, userID);
  12.         // 定义一个对未评论物品计算估值的实现类
  13.         TopItems.Estimator<Long> estimator = new Estimator(userID, theNeighborhood);
  14.         // 对物品估值后返回TopN
  15.         List<RecommendedItem> topItems = TopItems.getTopItems(howMany, allItemIDs.iterator(), rescorer, estimator);
  16.         return topItems;
  17.     }
  18.     // 使用内部类计算估值
  19.     private final class Estimator implements TopItems.Estimator<Long> {
  20.         private final long theUserID;
  21.         private final long[] theNeighborhood;
  22.         Estimator(long theUserID, long[] theNeighborhood) {
  23.             this.theUserID = theUserID;
  24.             this.theNeighborhood = theNeighborhood;
  25.         }
  26.         @Override
  27.         public double estimate(Long itemID) throws TasteException {
  28.             return doEstimatePreference(theUserID, theNeighborhood, itemID);
  29.         }
  30.     }
  31.     // 计算估值的实现方法
  32.     protected float doEstimatePreference(long theUserID, long[] theNeighborhood, long itemID) throws TasteException {
  33.         DataModel dataModel = getDataModel();
  34.         double preference = 0.0;
  35.         double totalSimilarity = 0.0;
  36.         int count = 0;
  37.         // 遍历所有邻居用户
  38.         for (long userID : theNeighborhood) {
  39.             if (userID != theUserID) {
  40.                 Float pref = dataModel.getPreferenceValue(userID, itemID);
  41.                 if (pref != null) {
  42.                     // 计算邻居用户与当前用户的相似度
  43.                     double theSimilarity = similarity.userSimilarity(theUserID, userID);
  44.                     if (!Double.isNaN(theSimilarity)) {
  45.                         // 加权过程实现
  46.                         preference += theSimilarity * pref;
  47.                         totalSimilarity += theSimilarity;
  48.                         count++;
  49.                     }
  50.                 }
  51.             }
  52.         }
  53.         if (count <= 1) {
  54.             return Float.NaN;
  55.         }
  56.         // 加权过程实现
  57.         float estimate = (float) (preference / totalSimilarity);
  58.         if (capper != null) {
  59.             estimate = capper.capEstimate(estimate);
  60.         }
  61.         return estimate;
  62.     }
  63. }
复制代码









引用:http://matrix-lisp.github.io/blo ... out-taste-source-1/


欢迎加入about云群90371779322273151432264021 ,云计算爱好者群,亦可关注about云腾讯认证空间||关注本站微信

没找到任何评论,期待你打破沉寂

您需要登录后才可以回帖 登录 | 立即注册

本版积分规则

关闭

推荐上一条 /2 下一条